Commit 7afeab07 by Tianqi Chen Committed by GitHub

[BUILD] Enable path option for ROCM, CUDA, Vulkan, simplify optional build (#1270)

parent d0eb2d3d
......@@ -3,7 +3,11 @@ project(tvm C CXX)
# Utility functions
include(cmake/util/Util.cmake)
include(cmake/util/FindCUDA.cmake)
include(cmake/util/FindVulkan.cmake)
include(cmake/util/FindLLVM.cmake)
include(cmake/util/FindROCM.cmake)
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/build/private/local_config.cmake)
include(${CMAKE_CURRENT_SOURCE_DIR}/build/private/local_config.cmake)
......
......@@ -22,12 +22,22 @@
#---------------------------------------------
# Backend runtimes.
#---------------------------------------------
# whether enable CUDA during compile
# Whether enable CUDA during compile,
#
# Possible values:
# - ON: enable CUDA with cmake's auto search
# - OFF: disbale CUDA
# - /path/to/cuda: use specific path to cuda toolkit
set(USE_CUDA OFF)
# ROCM
# Whether enable ROCM runtime
#
# Possible values:
# - ON: enable ROCM with cmake's auto search
# - OFF: disbale ROCM
# - /path/to/rocm: use specific path to rocm
set(USE_ROCM OFF)
set(ROCM_PATH "/opt/rocm")
# Whether enable OpenCL runtime
set(USE_OPENCL OFF)
......@@ -36,6 +46,11 @@ set(USE_OPENCL OFF)
set(USE_METAL OFF)
# Whether enable Vulkan runtime
#
# Possible values:
# - ON: enable Vulkan with cmake's auto search
# - OFF: disbale vulkan
# - /path/to/vulkan-sdk: use specific path to vulkan-sdk
set(USE_VULKAN OFF)
# Whether enable OpenGL runtime
......@@ -54,9 +69,9 @@ set(USE_GRAPH_RUNTIME_DEBUG OFF)
# Requires LLVM version >= 4.0
#
# Possible values:
# - ON: enable llvm with cmake's find llvm
# - ON: enable llvm with cmake's find search
# - OFF: disbale llvm
# - /path/to/llvm-config enable specific LLVM when multiple llvm-dev is available.
# - /path/to/llvm-config: enable specific LLVM when multiple llvm-dev is available.
set(USE_LLVM OFF)
#---------------------------------------------
......
# CUDA Module
find_package(CUDA QUIET)
find_cuda(${USE_CUDA})
if(CUDA_FOUND)
# always set the includedir when cuda is available
......@@ -8,69 +8,33 @@ if(CUDA_FOUND)
endif(CUDA_FOUND)
if(USE_CUDA)
find_package(CUDA REQUIRED)
# Find CUDA doesn't find all the libraries we need, add the extra ones
find_library(CUDA_CUDA_LIBRARIES cuda
PATHS ${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs)
find_library(CUDA_NVRTC_LIBRARIES nvrtc
PATHS ${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs)
if(CUDA_CUDA_LIBRARIES)
set(CUDA_CUDA_LIBRARY ${CUDA_CUDA_LIBRARIES})
if(NOT CUDA_FOUND)
message(FATAL_ERROR "Cannot find CUDA, USE_CUDA=" ${USE_CUDA})
endif()
message(STATUS "Build with CUDA support")
file(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDART_LIBRARY})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDA_LIBRARY})
list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_SRCS})
list(APPEND COMPILER_SRCS src/codegen/opt/build_cuda_on.cc)
if(MSVC)
find_library(CUDA_NVRTC_LIB nvrtc
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
${CUDA_TOOLKIT_ROOT_DIR}/lib/win32)
list(APPEND TVM_LINKER_LIBS ${CUDA_NVRTC_LIB})
else(MSVC)
find_library(CUDA_NVRTC_LIB nvrtc
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib)
list(APPEND TVM_LINKER_LIBS ${CUDA_NVRTC_LIB})
endif(MSVC)
list(APPEND TVM_LINKER_LIBS ${CUDA_NVRTC_LIBRARY})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDART_LIBRARY})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDA_LIBRARY})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_NVRTC_LIBRARY})
if(USE_CUDNN)
message(STATUS "Build with cuDNN support")
file(GLOB CONTRIB_CUDNN_SRCS src/contrib/cudnn/*.cc)
list(APPEND RUNTIME_SRCS ${CONTRIB_CUDNN_SRCS})
if(MSVC)
find_library(CUDA_CUDNN_LIB cudnn
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
${CUDA_TOOLKIT_ROOT_DIR}/lib/win32)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDNN_LIB})
else(MSVC)
find_library(CUDA_CUDNN_LIB cudnn
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDNN_LIB})
endif(MSVC)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDNN_LIBRARY})
endif(USE_CUDNN)
if(USE_CUBLAS)
message(STATUS "Build with cuBLAS support")
file(GLOB CONTRIB_CUBLAS_SRCS src/contrib/cublas/*.cc)
list(APPEND RUNTIME_SRCS ${CONTRIB_CUBLAS_SRCS})
if(MSVC)
find_library(CUDA_CUBLAS_LIB cublas
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
${CUDA_TOOLKIT_ROOT_DIR}/lib/win32)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUBLAS_LIB})
else(MSVC)
find_library(CUDA_CUBLAS_LIB cublas
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUBLAS_LIB})
endif(MSVC)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUBLAS_LIBRARY})
endif(USE_CUBLAS)
else(USE_CUDA)
list(APPEND COMPILER_SRCS src/codegen/opt/build_cuda_off.cc)
endif(USE_CUDA)
......@@ -2,12 +2,7 @@
add_definitions(-DDMLC_USE_FOPEN64=0)
if(NOT USE_LLVM STREQUAL "OFF")
if(NOT USE_LLVM STREQUAL "ON")
set(LLVM_CONFIG "${USE_LLVM}")
else()
set(LLVM_CONFIG "")
endif()
find_llvm()
find_llvm(${USE_LLVM})
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
message(STATUS "Build with LLVM " ${LLVM_PACKAGE_VERSION})
......
......@@ -5,7 +5,6 @@ if(USE_METAL)
file(GLOB RUNTIME_METAL_SRCS src/runtime/metal/*.mm)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${METAL_LIB} ${FOUNDATION_LIB})
list(APPEND RUNTIME_SRCS ${RUNTIME_METAL_SRCS})
list(APPEND COMPILER_SRCS src/codegen/opt/build_metal_on.cc)
if(USE_MPS)
file(GLOB MPS_CONTRIB_SRC src/contrib/mps/*.mm)
......
......@@ -13,7 +13,6 @@ if(USE_OPENCL)
file(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenCL_LIBRARIES})
list(APPEND RUNTIME_SRCS ${RUNTIME_OPENCL_SRCS})
list(APPEND COMPILER_SRCS src/codegen/opt/build_opencl_on.cc)
else()
list(APPEND COMPILER_SRCS src/codegen/opt/build_opencl_off.cc)
endif(USE_OPENCL)
......@@ -13,7 +13,6 @@ if(USE_OPENGL)
file(GLOB RUNTIME_OPENGL_SRCS src/runtime/opengl/*.cc)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenGL_LIBRARIES} glfw)
list(APPEND RUNTIME_SRCS ${RUNTIME_OPENGL_SRCS})
list(APPEND COMPILER_SRCS src/codegen/opt/build_opengl_on.cc)
else(USE_OPENGL)
list(APPEND COMPILER_SRCS src/codegen/opt/build_opengl_off.cc)
endif(USE_OPENGL)
# ROCM Module
if(NOT ROCM_PATH STREQUAL "")
include_directories(${ROCM_PATH}/include)
set(ROCM_LIB_PATH ${ROCM_PATH}/lib)
else()
set(ROCM_LIB_PATH /lib)
endif()
find_rocm(${USE_ROCM})
if(ROCM_FOUND)
# always set the includedir
# avoid global retrigger of cmake
include_directories(${ROCM_INCLUDE_DIRS})
add_definitions(-D__HIP_PLATFORM_HCC__=1)
endif(ROCM_FOUND)
if(USE_ROCM)
if(NOT ROCM_FOUND)
message(FATAL_ERROR "Cannot find ROCM, USE_ROCM=" ${USE_ROCM})
endif()
message(STATUS "Build with ROCM support")
find_library(ROCM_LIBS hip_hcc ${ROCM_LIB_PATH})
file(GLOB RUNTIME_ROCM_SRCS src/runtime/rocm/*.cc)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_LIBS})
list(APPEND RUNTIME_SRCS ${RUNTIME_ROCM_SRCS})
add_definitions(-DTVM_ROCM_RUNTIME=1 -D__HIP_PLATFORM_HCC__=1)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_HIPHCC_LIBRARY})
if(USE_MIOPEN)
message(STATUS "Build with MIOpen support")
file(GLOB MIOPEN_CONTRIB_SRCS src/contrib/miopen/*.cc)
list(APPEND RUNTIME_SRCS ${MIOPEN_CONTRIB_SRCS})
find_library(MIOPEN_LIBS MIOpen ${ROCM_LIB_PATH})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${MIOPEN_LIBS})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_MIOPEN_LIBRARY})
endif(USE_MIOPEN)
if(USE_ROCBLAS)
message(STATUS "Build with RocBLAS support")
file(GLOB ROCBLAS_CONTRIB_SRCS src/contrib/rocblas/*.cc)
list(APPEND RUNTIME_SRCS ${ROCBLAS_CONTRIB_SRCS})
find_library(ROCBLAS_LIBS rocblas ${ROCM_LIB_PATH})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCBLAS_LIBS})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_ROCBLAS_LIBRARY})
endif(USE_ROCBLAS)
else(USE_ROCM)
add_definitions(-DTVM_ROCM_RUNTIME=0)
list(APPEND COMPILER_SRCS src/codegen/opt/build_rocm_off.cc)
endif(USE_ROCM)
# Be compatible with older version of CMake
if(NOT $ENV{VULKAN_SDK} STREQUAL "")
set(Vulkan_INCLUDE_DIRS $ENV{VULKAN_SDK}/include)
set(Vulkan_FOUND ON)
else()
find_package(Vulkan QUIET)
endif()
find_vulkan(${USE_VULKAN})
if(Vulkan_FOUND)
# always set the includedir when cuda is available
# always set the includedir
# avoid global retrigger of cmake
include_directories(${Vulkan_INCLUDE_DIRS})
endif(Vulkan_FOUND)
if(USE_VULKAN)
if(NOT $ENV{VULKAN_SDK} STREQUAL "")
find_library(Vulkan_LIBRARY vulkan $ENV{VULKAN_SDK}/lib)
else()
find_package(Vulkan REQUIRED)
if(NOT Vulkan_FOUND)
message(FATAL_ERROR "Cannot find Vulkan, USE_VULKAN=" ${USE_VULKAN})
endif()
message(STATUS "Build with VULKAN support")
file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc)
file(GLOB COMPILER_VULKAN_SRCS src/codegen/spirv/*.cc)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARY})
list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS})
list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS})
get_filename_component(VULKAN_LIB_PATH ${Vulkan_LIBRARY} DIRECTORY)
find_library(SPIRV_TOOLS_LIB SPIRV-Tools
${VULKAN_LIB_PATH}/spirv-tools)
list(APPEND TVM_LINKER_LIBS ${SPIRV_TOOLS_LIB})
list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${Vulkan_LIBRARY})
endif(USE_VULKAN)
#######################################################
# Enhanced version of find CUDA.
#
# Usage:
# find_cuda(${USE_CUDA})
#
# - When USE_CUDA=ON, use auto search
# - When USE_CUDA=/path/to/cuda-path, use the cuda path
#
# Provide variables:
#
# - CUDA_FOUND
# - CUDA_INCLUDE_DIRS
# - CUDA_TOOLKIT_ROOT_DIR
# - CUDA_CUDA_LIBRARY
# - CUDA_CUDART_LIBRARY
# - CUDA_NVRTC_LIBRARY
# - CUDA_CUDNN_LIBRARY
# - CUDA_CUBLAS_LIBRARY
#
macro(find_cuda use_cuda)
set(__use_cuda ${use_cuda})
if(__use_cuda STREQUAL "ON")
find_package(CUDA QUIET)
elseif(IS_DIRECTORY ${__use_cuda})
set(CUDA_TOOLKIT_ROOT_DIR ${__use_cuda})
message(STATUS "Custom CUDA_PATH=" ${CUDA_TOOLKIT_ROOT_DIR})
set(CUDA_INCLUDE_DIRS ${CUDA_TOOLKIT_ROOT_DIR}/include)
set(CUDA_FOUND TRUE)
if(MSVC)
find_library(CUDA_CUDAT_LIBRARY cudart
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib)
else(MSVC)
find_library(CUDA_CUDAT_LIBRARY cudart
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib)
endif(MSVC)
endif()
# additional libraries
if(CUDA_FOUND)
if(MSVC)
find_library(CUDA_NVRTC_LIBRARY cuda
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
${CUDA_TOOLKIT_ROOT_DIR}/lib/win32)
find_library(CUDA_NVRTC_LIBRARY nvrtc
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
${CUDA_TOOLKIT_ROOT_DIR}/lib/win32)
find_library(CUDA_CUDNN_LIBRARY cudnn
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
${CUDA_TOOLKIT_ROOT_DIR}/lib/win32)
else(MSVC)
find_library(_CUDA_CUDA_LIBRARY cuda
PATHS ${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs)
if(_CUDA_CUDA_LIBRARY)
set(CUDA_CUDA_LIBRARY ${_CUDA_CUDA_LIBRARY})
endif()
find_library(CUDA_NVRTC_LIBRARY nvrtc
PATHS ${CUDA_TOOLKIT_ROOT_DIR}
PATH_SUFFIXES lib lib64 targets/x86_64-linux/lib targets/x86_64-linux/lib/stubs)
find_library(CUDA_CUDNN_LIBRARY cudnn
${CUDA_TOOLKIT_ROOT_DIR}/lib64
${CUDA_TOOLKIT_ROOT_DIR}/lib)
endif(MSVC)
endif(CUDA_FOUND)
endmacro(find_cuda)
#######################################################
# Enhanced version of find llvm that allows set of LLVM_CONFIG
# When LLVM_CONFIG_PATH is AUTO,
# it defaults to system find llvm
# Enhanced version of find llvm.
#
# Usage:
# find_llvm(LLVM_CONFIG_PATH)
# find_llvm(${USE_LLVM})
#
# - When USE_LLVM=ON, use auto search
# - When USE_LLVM=/path/to/llvm-config, use corresponding config
#
# Provide variables:
# - LLVM_INCLUDE_DIRS
# - LLVM_LIBS
# - LLVM_DEFINITIONS
# - LLVM_VERSION_CONCAT
# - TVM_LLVM_VERISON
#
macro(find_llvm)
if(LLVM_CONFIG STREQUAL "")
macro(find_llvm use_llvm)
set(LLVM_CONFIG ${use_llvm})
if(LLVM_CONFIG STREQUAL "ON")
find_package(LLVM REQUIRED CONFIG)
llvm_map_components_to_libnames(LLVM_LIBS all)
list(REMOVE_ITEM LLVM_LIBS LTO)
set(TVM_LLVM_VERSION ${LLVM_VERSION_MAJOR}${LLVM_VERSION_MINOR})
else()
elseif(NOT LLVM_CONFIG STREQUAL "OFF")
# use llvm config
message(STATUS "Use llvm-config=" ${LLVM_CONFIG})
execute_process(COMMAND ${LLVM_CONFIG} --includedir
......
#######################################################
# Enhanced version of find rocm.
#
# Usage:
# find_rocm(${USE_ROCM})
#
# - When USE_VULKAN=ON, use auto search
# - When USE_VULKAN=/path/to/vulkan-sdk-path, use the sdk
#
# Provide variables:
#
# - ROCM_FOUND
# - ROCM_INCLUDE_DIRS
# - ROCM_HIPHCC_LIBRARY
# - ROCM_MIOPEN_LIBRARY
# - ROCM_ROCBLAS_LIBRARY
#
macro(find_rocm use_rocm)
set(__use_rocm ${use_rocm})
if(IS_DIRECTORY ${__use_rocm})
set(__rocm_sdk ${__use_rocm})
message(STATUS "Custom ROCM SDK PATH=" ${__use_rocm})
elseif(IS_DIRECTORY $ENV{ROCM_PATH})
set(__rocm_sdk $ENV{ROCM_PATH})
elseif(IS_DIRECTORY /opt/rocm)
set(__rocm_sdk /opt/rocm)
else()
set(__rocm_sdk "")
endif()
if(__rocm_sdk)
set(ROCM_INCLUDE_DIRS ${__rocm_sdk}/include)
find_library(ROCM_HIPHCC_LIBRARY hip_hcc ${__rocm_sdk}/lib)
find_library(ROCM_MIOPEN_LIBRARY MIOpen ${__rocm_sdk}/lib)
find_library(ROCM_ROCBLAS_LIBRARY rocblas ${__rocm_sdk}/lib)
if(ROCM_HIPHCC_LIBRARY)
set(ROCM_FOUND TRUE)
endif()
endif(__rocm_sdk)
endmacro(find_rocm)
#######################################################
# Enhanced version of find Vulkan.
#
# Usage:
# find_vulkan(${USE_VULKAN})
#
# - When USE_VULKAN=ON, use auto search
# - When USE_VULKAN=/path/to/vulkan-sdk-path, use the sdk
#
# Provide variables:
#
# - Vulkan_FOUND
# - Vulkan_INCLUDE_DIRS
# - Vulkan_LIBRARY
# - Vulkan_SPIRV_TOOLS_LIBRARY
#
macro(find_vulkan use_vulkan)
set(__use_vulkan ${use_vulkan})
if(IS_DIRECTORY ${__use_vulkan})
set(__vulkan_sdk ${__use_vulkan})
message(STATUS "Custom Vulkan SDK PATH=" ${__use_vulkan})
elseif(IS_DIRECTORY $ENV{VULKAN_SDK})
set(__vulkan_sdk $ENV{VULKAN_SDK})
else()
set(__vulkan_sdk "")
endif()
if(__vulkan_sdk)
set(Vulkan_INCLUDE_DIRS ${__vulkan_sdk}/include)
find_library(Vulkan_LIBRARY vulkan ${__vulkan_sdk}/lib)
if(Vulkan_LIBRARY)
set(Vulkan_FOUND TRUE)
endif()
endif(__vulkan_sdk)
# resort to find vulkan of option is on
if(NOT Vulkan_FOUND)
if(__use_vulkan STREQUAL "ON")
find_package(Vulkan QUIET)
endif()
endif()
# additional libraries
if(Vulkan_FOUND)
get_filename_component(VULKAN_LIBRARY_PATH ${Vulkan_LIBRARY} DIRECTORY)
find_library(Vulkan_SPIRV_TOOLS_LIBRARY SPIRV-Tools
${VULKAN_LIBRARY_PATH}/spirv-tools)
endif(Vulkan_FOUND)
endmacro(find_vulkan)
......@@ -6,6 +6,8 @@
#include <vector>
#include <string>
#include "./codegen_metal.h"
#include "./build_common.h"
#include "../runtime/metal/metal_module.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
......@@ -220,5 +222,29 @@ void CodeGenMetal::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOLI
}
os << ')';
}
runtime::Module BuildMetal(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenMetal cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
}
std::string code = cg.Finish();
std::string fmt = "metal";
std::string source = "";
if (const auto* f = Registry::Get("tvm_callback_metal_compile")) {
source = code;
code = (*f)(code).operator std::string();
fmt = "metallib";
}
return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source);
}
TVM_REGISTER_API("codegen.build_metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildMetal(args[0]);
});
} // namespace codegen
} // namespace tvm
......@@ -6,7 +6,9 @@
#include <vector>
#include <string>
#include "./codegen_opencl.h"
#include "./build_common.h"
#include "../runtime/thread_storage_scope.h"
#include "../runtime/opencl/opencl_module.h"
namespace tvm {
namespace codegen {
......@@ -202,5 +204,26 @@ void CodeGenOpenCL::VisitExpr_(const Broadcast* op, std::ostream& os) { // NOL
}
os << "))";
}
runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenOpenCL cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) {
code = (*f)(code).operator std::string();
}
return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs));
}
TVM_REGISTER_API("codegen.build_opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildOpenCL(args[0]);
});
} // namespace codegen
} // namespace tvm
......@@ -9,6 +9,7 @@
#include <vector>
#include <string>
#include "./codegen_opengl.h"
#include "./build_common.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
......@@ -268,5 +269,21 @@ void CodeGenOpenGL::VisitStmt_(const Evaluate* op) {
this->stream << GetVarID(buffer) << " = " << PrintExpr(value) << ";\n";
}
runtime::Module BuildOpenGL(Array<LoweredFunc> funcs) {
bool output_ssa = false;
CodeGenOpenGL cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
}
auto shaders = cg.Finish();
return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(funcs));
}
TVM_REGISTER_API("codegen.build_opengl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildOpenGL(args[0]);
});
} // namespace codegen
} // namespace tvm
......@@ -12,10 +12,7 @@
#include "../build_common.h"
#include "../codegen_source_base.h"
#include "../../pass/ir_util.h"
#if TVM_ROCM_RUNTIME
#include "../../runtime/rocm/rocm_module.h"
#endif // TVM_ROCM_RUNTIME
namespace tvm {
namespace codegen {
......@@ -242,20 +239,7 @@ runtime::Module BuildAMDGPU(Array<LoweredFunc> funcs, std::string target) {
std::string hsaco = (*f)(arr);
std::string ll(data_ll.begin(), data_ll.end());
#if TVM_ROCM_RUNTIME
return ROCMModuleCreate(hsaco, "hsaco", ExtractFuncInfo(funcs), ll, assembly);
#else
LOG(WARNING) << "ROCM runtime is not enabled, return a source module...";
auto fget_source = [ll, assembly](const std::string& format) {
if (format.length() == 0) return assembly;
if (format == "ll" || format == "llvm") return format;
if (format == "asm") return assembly;
return std::string("");
};
return DeviceSourceModuleCreate(
hsaco, "hsaco", ExtractFuncInfo(funcs), "hsaco", fget_source);
#endif // TVM_ROCM_RUNTIME
}
TVM_REGISTER_API("codegen.build_rocm")
......
/*!
* Copyright (c) 2017 by Contributors
* Build metal modules from source.
* \file build_metal.h
*/
#ifndef TVM_CODEGEN_OPT_BUILD_METAL_H_
#define TVM_CODEGEN_OPT_BUILD_METAL_H_
#include <string>
#include "../codegen_metal.h"
#include "../build_common.h"
#if TVM_METAL_RUNTIME
#include "../../runtime/metal/metal_module.h"
#endif // TVM_METAL_RUNTIME
namespace tvm {
namespace codegen {
runtime::Module BuildMetal(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenMetal cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
}
std::string code = cg.Finish();
#if TVM_METAL_RUNTIME
std::string fmt = "metal";
std::string source = "";
if (const auto* f = Registry::Get("tvm_callback_metal_compile")) {
source = code;
code = (*f)(code).operator std::string();
fmt = "metallib";
}
return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source);
#else
LOG(WARNING) << "Metal runtime not enabled, return a source module...";
return DeviceSourceModuleCreate(code, "metal", ExtractFuncInfo(funcs), "metal");
#endif // TVM_METAL_RUNTIME
}
TVM_REGISTER_API("codegen.build_metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildMetal(args[0]);
});
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_OPT_BUILD_METAL_H_
/*!
* Copyright (c) 2018 by Contributors
* Build Metal modules off
* Optional module when build metal is switched to off
*/
#define TVM_METAL_RUNTIME 0
#include "./build_metal.h"
#include "../codegen_source_base.h"
#include "../../runtime/metal/metal_module.h"
namespace tvm {
namespace runtime {
Module MetalModuleCreate(std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string source) {
LOG(WARNING) << "Metal runtime not enabled, return a source module...";
return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "metal");
}
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* Build Metal modules on
*/
#define TVM_METAL_RUNTIME 1
#include "./build_metal.h"
/*!
* Copyright (c) 2017 by Contributors
* Build opencl modules from source.
* \file build_opencl.h
*/
#ifndef TVM_CODEGEN_OPT_BUILD_OPENCL_H_
#define TVM_CODEGEN_OPT_BUILD_OPENCL_H_
#include <tvm/base.h>
#include <string>
#include "../codegen_opencl.h"
#include "../build_common.h"
#if TVM_OPENCL_RUNTIME
#include "../../runtime/opencl/opencl_module.h"
#endif // TVM_OPENCL_RUNTIME
namespace tvm {
namespace codegen {
runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenOpenCL cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_opencl_postproc")) {
code = (*f)(code).operator std::string();
}
#if TVM_OPENCL_RUNTIME
return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs));
#else
LOG(WARNING) << "OpenCL runtime not enabled, return a source module...";
return DeviceSourceModuleCreate(code, "cl", ExtractFuncInfo(funcs), "opencl");
#endif // TVM_OPENCL_RUNTIME
}
TVM_REGISTER_API("codegen.build_opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildOpenCL(args[0]);
});
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_OPT_BUILD_OPENCL_H_
/*!
* Copyright (c) 2018 by Contributors
* Build opencl modules off
* Optional module when build opencl is switched to off
*/
#define TVM_OPENCL_RUNTIME 0
#include "./build_opencl.h"
#include "../codegen_source_base.h"
#include "../../runtime/opencl/opencl_module.h"
namespace tvm {
namespace runtime {
Module OpenCLModuleCreate(
std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap) {
LOG(WARNING) << "OpenCL runtime not enabled, return a source module...";
return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "opencl");
}
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* Build opencl modules on
*/
#define TVM_OPENCL_RUNTIME 1
#include "./build_opencl.h"
/*!
* Copyright (c) 2017 by Contributors
* Build opengl modules from source.
* \file build_opengl.h
*/
#ifndef TVM_CODEGEN_OPT_BUILD_OPENGL_H_
#define TVM_CODEGEN_OPT_BUILD_OPENGL_H_
#include <tvm/base.h>
#include "../codegen_opengl.h"
#include "../build_common.h"
namespace tvm {
namespace codegen {
runtime::Module BuildOpenGL(Array<LoweredFunc> funcs) {
bool output_ssa = false;
CodeGenOpenGL cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
}
auto shaders = cg.Finish();
#if TVM_OPENGL_RUNTIME
return OpenGLModuleCreate(shaders, "gl", ExtractFuncInfo(funcs));
#else
LOG(WARNING) << "OpenGL runtime not enabled, return a source module...";
auto data = ToJSON(shaders);
return DeviceSourceModuleCreate(data, "gl", ExtractFuncInfo(funcs), "opengl");
#endif // TVM_OPENGL_RUNTIME
}
TVM_REGISTER_API("codegen.build_opengl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildOpenGL(args[0]);
});
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_OPT_BUILD_OPENGL_H_
/*!
* Copyright (c) 2018 by Contributors
* Build OpenGL modules off
* Optional module when build opencl is switched to off
*/
#define TVM_OPENGL_RUNTIME 0
#include "./build_opengl.h"
#include "../codegen_source_base.h"
#include "../../runtime/opengl/opengl_module.h"
namespace tvm {
namespace runtime {
Module OpenGLModuleCreate(std::unordered_map<std::string, OpenGLShader> shaders,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap) {
LOG(WARNING) << "OpenGL runtime not enabled, return a source module...";
auto data = ToJSON(shaders);
return codegen::DeviceSourceModuleCreate(data, "gl", fmap, "opengl");
}
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2018 by Contributors
* Build OpenGL modules on
*/
#define TVM_OPENGL_RUNTIME 1
#include "./build_opengl.h"
/*!
* Copyright (c) 2018 by Contributors
* Optional module when build rocm is switched to off
*/
#include "../codegen_source_base.h"
#include "../../runtime/rocm/rocm_module.h"
namespace tvm {
namespace runtime {
Module ROCMModuleCreate(
std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string rocm_source,
std::string assembly) {
LOG(WARNING) << "ROCM runtime is not enabled, return a source module...";
auto fget_source = [rocm_source, assembly](const std::string& format) {
if (format.length() == 0) return assembly;
if (format == "ll" || format == "llvm") return rocm_source;
if (format == "asm") return assembly;
return std::string("");
};
return codegen::DeviceSourceModuleCreate(
data, fmt, fmap, "hsaco", fget_source);
}
} // namespace runtime
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment