Commit e4387940 by Tianqi Chen Committed by GitHub

[BUILD] Windows build pass on LLVM/CUDA/OPENCL (#57)

parent 33310206
...@@ -92,6 +92,6 @@ ENV/ ...@@ -92,6 +92,6 @@ ENV/
*~ *~
build build
config.mk config.mk
build_win build_*
Win32 Win32
*.dir *.dir
cmake_minimum_required(VERSION 3.5) cmake_minimum_required(VERSION 3.5)
project(tvm) project(tvm C CXX)
include(cmake/Util.cmake) if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/build/private/local_config.cmake)
include(${CMAKE_CURRENT_SOURCE_DIR}/build/private/local_config.cmake)
endif()
option(USE_OPENCL "Build with OpenCL" OFF) include(cmake/Util.cmake)
option(USE_CUDA "Build with CUDA" OFF) tvm_option(USE_CUDA "Build with CUDA" ON)
option(USE_LLVM "Build with LLVM" OFF) tvm_option(USE_OPENCL "Build with OpenCL" ON)
option(USE_RTTI "Build with RTTI" OFF) tvm_option(USE_LLVM "Build with LLVM" OFF)
tvm_option(USE_RTTI "Build with RTTI" OFF)
tvm_option(USE_MSVC_MT "Build with MT" OFF)
# include path
include_directories("include") include_directories("include")
include_directories("HalideIR/src") include_directories("HalideIR/src")
set(TVM_LINKER_LIBS "") set(TVM_LINKER_LIBS "")
...@@ -20,24 +23,22 @@ if(MSVC) ...@@ -20,24 +23,22 @@ if(MSVC)
add_definitions(-D_CRT_SECURE_NO_WARNINGS) add_definitions(-D_CRT_SECURE_NO_WARNINGS)
add_definitions(-D_SCL_SECURE_NO_WARNINGS) add_definitions(-D_SCL_SECURE_NO_WARNINGS)
add_definitions(-DTVM_EXPORTS) add_definitions(-DTVM_EXPORTS)
# MSVC: force the static C runtime by rewriting /MD (DLL CRT) to /MT in
# every per-configuration C++ flag variable, so the produced binaries do
# not depend on the MSVC runtime redistributable.
foreach(flag_var
    CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
    CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
  if(${flag_var} MATCHES "/MD")
    string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
  endif()
endforeach()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /MP") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /bigobj")
# Optionally link the static MSVC runtime: when USE_MSVC_MT is ON, rewrite
# /MD (DLL CRT) to /MT in every per-configuration C++ flag variable so the
# resulting binaries carry no dependency on the MSVC runtime DLLs.
if(USE_MSVC_MT)
  foreach(flag_var
      CMAKE_CXX_FLAGS CMAKE_CXX_FLAGS_DEBUG CMAKE_CXX_FLAGS_RELEASE
      CMAKE_CXX_FLAGS_MINSIZEREL CMAKE_CXX_FLAGS_RELWITHDEBINFO)
    if(${flag_var} MATCHES "/MD")
      string(REGEX REPLACE "/MD" "/MT" ${flag_var} "${${flag_var}}")
    endif()
  endforeach()
endif()
else(MSVC) else(MSVC)
include(CheckCXXCompilerFlag) include(CheckCXXCompilerFlag)
check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11) check_cxx_compiler_flag("-std=c++11" SUPPORT_CXX11)
check_cxx_compiler_flag("-msse2" SUPPORT_MSSE2)
set(CMAKE_C_FLAGS "-O3 -fno-rtti -Wall -std=c++11 -fPIC") set(CMAKE_C_FLAGS "-O3 -fno-rtti -Wall -std=c++11 -fPIC")
if(SUPPORT_OPENMP)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fopenmp")
endif()
set(CMAKE_CXX_FLAGS ${CMAKE_C_FLAGS}) set(CMAKE_CXX_FLAGS ${CMAKE_C_FLAGS})
endif(MSVC) endif(MSVC)
...@@ -49,6 +50,7 @@ tvm_source_group("Source\\arithmetic" GLOB "src/arithmetic/*.cc") ...@@ -49,6 +50,7 @@ tvm_source_group("Source\\arithmetic" GLOB "src/arithmetic/*.cc")
tvm_source_group("Source\\schedule" GLOB "src/schedule/*.cc") tvm_source_group("Source\\schedule" GLOB "src/schedule/*.cc")
tvm_source_group("Source\\codegen" GLOB "src/codegen/*.cc") tvm_source_group("Source\\codegen" GLOB "src/codegen/*.cc")
tvm_source_group("Source\\codegen\\llvm" GLOB "src/codegen/llvm/*.cc") tvm_source_group("Source\\codegen\\llvm" GLOB "src/codegen/llvm/*.cc")
tvm_source_group("Source\\codegen\\stack_vm" GLOB "src/codegen/stack_vm/*.cc")
tvm_source_group("Source\\pass" GLOB "src/pass/*.cc") tvm_source_group("Source\\pass" GLOB "src/pass/*.cc")
tvm_source_group("Source\\runtime" GLOB "src/runtime/*.cc") tvm_source_group("Source\\runtime" GLOB "src/runtime/*.cc")
tvm_source_group("Source\\runtime\\cuda" GLOB "src/runtime/cuda/*.cc") tvm_source_group("Source\\runtime\\cuda" GLOB "src/runtime/cuda/*.cc")
...@@ -58,7 +60,7 @@ file(GLOB COMPILER_SRCS ...@@ -58,7 +60,7 @@ file(GLOB COMPILER_SRCS
src/api/*.cc src/api/*.cc
src/arithmetic/*.cc src/arithmetic/*.cc
src/codegen/*.cc src/codegen/*.cc
src/stack_vm/*.cc src/codegen/stack_vm/*.cc
src/lang/*.cc src/lang/*.cc
src/pass/*.cc src/pass/*.cc
src/schedule/*.cc src/schedule/*.cc
...@@ -71,19 +73,44 @@ file(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc) ...@@ -71,19 +73,44 @@ file(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc)
file(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc) file(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc)
if(USE_CUDA) if(USE_CUDA)
find_package(CUDA)
find_package(CUDA QUIET REQUIRED)
message(STATUS "Build with CUDA support...")
include_directories(${CUDA_INCLUDE_DIRS})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDART_LIBRARY})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDA_LIBRARY})
list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_SRCS}) list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_SRCS})
if(MSVC)
find_library(CUDA_NVRTC_LIB nvrtc
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
${CUDA_TOOLKIT_ROOT_DIR}/lib/win32)
list(APPEND TVM_LINKER_LIBS ${CUDA_NVRTC_LIB})
endif()
add_definitions(-DTVM_CUDA_RUNTIME=1)
else(USE_CUDA) else(USE_CUDA)
add_definitions(-DTVM_CUDA_RUNTIME=0) add_definitions(-DTVM_CUDA_RUNTIME=0)
endif(USE_CUDA) endif(USE_CUDA)
if(USE_OPENCL) if(USE_OPENCL)
find_package(OPENCL QUIET REQUIRED)
message(STATUS "Build with OpenCL support...")
include_directories(${OPENCL_INCLUDE_DIRS})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenCL_LIBRARIES})
list(APPEND RUNTIME_SRCS ${RUNTIME_OPENCL_SRCS}) list(APPEND RUNTIME_SRCS ${RUNTIME_OPENCL_SRCS})
add_definitions(-DTVM_OPENCL_RUNTIME=1)
else(USE_OPENCL) else(USE_OPENCL)
add_definitions(-DTVM_OPENCL_RUNTIME=0) add_definitions(-DTVM_OPENCL_RUNTIME=0)
endif(USE_OPENCL) endif(USE_OPENCL)
if(USE_LLVM) if(USE_LLVM)
add_definitions(-DTVM_LLVM_VERSION=40) find_package(LLVM REQUIRED CONFIG)
message(STATUS "Build with LLVM support...")
include_directories(${LLVM_INCLUDE_DIRS})
add_definitions(${LLVM_DEFINITIONS})
llvm_map_components_to_libnames(LLVM_LIBS all)
list(REMOVE_ITEM LLVM_LIBS LTO)
list(APPEND TVM_LINKER_LIBS ${LLVM_LIBS})
add_definitions(-DTVM_LLVM_VERSION=${LLVM_PACKAGE_VERSION})
list(APPEND COMPILER_SRCS ${COMPILER_LLVM_SRCS}) list(APPEND COMPILER_SRCS ${COMPILER_LLVM_SRCS})
endif(USE_LLVM) endif(USE_LLVM)
...@@ -109,9 +136,7 @@ else() ...@@ -109,9 +136,7 @@ else()
set(CMAKE_SHARED_LIBRARY_PREFIX "") set(CMAKE_SHARED_LIBRARY_PREFIX "")
endif() endif()
add_library(libtvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) add_library(libtvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
add_library(libtvm_runtime SHARED ${RUNTIME_SRCS}) add_library(libtvm_runtime SHARED ${RUNTIME_SRCS})
target_link_libraries(libtvm ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS})
target_link_libraries(libtvm ${TVM_LINKER_LIBS})
target_link_libraries(libtvm_runtime ${TVM_RUNTIME_LINKER_LIBS}) target_link_libraries(libtvm_runtime ${TVM_RUNTIME_LINKER_LIBS})
...@@ -12,3 +12,44 @@ function(tvm_source_group group) ...@@ -12,3 +12,44 @@ function(tvm_source_group group)
source_group(${group} FILES ${srcs2}) source_group(${group} FILES ${srcs2})
endif() endif()
endfunction() endfunction()
#######################################################
# An option that the user can select. Can accept a condition that controls
# when the option is available to the user.
# Usage:
#   tvm_option(<option_variable> "doc string" <initial value or boolean expression> [IF <condition>])
# Arguments:
#   variable    - name of the boolean cache variable to declare
#   description - docstring recorded with the cache entry
#   value       - default: a literal ON/OFF, the name of another variable,
#                 or (via extra arguments) a multi-token boolean expression
#   IF <cond>   - optional gating condition; when it evaluates false, the
#                 option is removed from the cache instead of declared
function(tvm_option variable description value)
# __value accumulates the default-value tokens, __condition the tokens
# that follow an IF keyword.
set(__value ${value})
set(__condition "")
# __varname holds the NAME of whichever list is currently being filled;
# it is dereferenced inside the loop below.
set(__varname "__value")
foreach(arg ${ARGN})
if(arg STREQUAL "IF" OR arg STREQUAL "if")
# Switch append target: everything after IF belongs to the condition.
set(__varname "__condition")
else()
list(APPEND ${__varname} ${arg})
endif()
endforeach()
unset(__varname)
# No IF clause supplied: substitute an always-true condition (2 GREATER 1)
# so the single evaluation path below can be used unconditionally.
if("${__condition}" STREQUAL "")
set(__condition 2 GREATER 1)
endif()
# __condition is a ;-list; the unquoted expansion splats it back into
# if() as a full condition expression.
if(${__condition})
if("${__value}" MATCHES ";")
# Default contains ';' => it is a multi-token boolean expression;
# evaluate it and declare the option ON or OFF accordingly.
if(${__value})
option(${variable} "${description}" ON)
else()
option(${variable} "${description}" OFF)
endif()
elseif(DEFINED ${__value})
# Default names an already-defined variable: use its truth value.
if(${__value})
option(${variable} "${description}" ON)
else()
option(${variable} "${description}" OFF)
endif()
else()
# Plain literal default (e.g. ON / OFF).
option(${variable} "${description}" ${__value})
endif()
else()
# Gating condition is false: hide the option by dropping any cached value.
unset(${variable} CACHE)
endif()
endfunction()
\ No newline at end of file
...@@ -46,7 +46,8 @@ def compile_source(code, target="ptx", arch=None, ...@@ -46,7 +46,8 @@ def compile_source(code, target="ptx", arch=None,
file_target = path_target if path_target else temp_target file_target = path_target if path_target else temp_target
cmd = ["nvcc"] cmd = ["nvcc"]
cmd += ["--%s" % target, "-O3"] cmd += ["--%s" % target, "-O3"]
cmd += ["-arch", arch] if arch:
cmd += ["-arch", arch]
cmd += ["-o", file_target] cmd += ["-o", file_target]
if options: if options:
......
...@@ -44,7 +44,7 @@ def find_lib_path(): ...@@ -44,7 +44,7 @@ def find_lib_path():
raise RuntimeError('Cannot find the files.\n' + raise RuntimeError('Cannot find the files.\n' +
'List of candidates:\n' + str('\n'.join(dll_path))) 'List of candidates:\n' + str('\n'.join(dll_path)))
if use_runtime: if use_runtime:
sys.stderr.write("Loading runtime library... this is execution only\n") sys.stderr.write("Loading runtime library %s... exec only\n" % lib_found[0])
sys.stderr.flush() sys.stderr.flush()
return lib_found return lib_found
......
...@@ -32,7 +32,7 @@ class Module(ModuleBase): ...@@ -32,7 +32,7 @@ class Module(ModuleBase):
modules : list of Modules modules : list of Modules
The module The module
""" """
nmod = ImportsSize(self) nmod = _ImportsSize(self)
return [_GetImport(self, i) for i in range(nmod)] return [_GetImport(self, i) for i in range(nmod)]
def save(self, file_name, fmt=""): def save(self, file_name, fmt=""):
......
...@@ -39,7 +39,6 @@ ...@@ -39,7 +39,6 @@
#include <utility> #include <utility>
#include <string> #include <string>
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#ifdef TVM_LLVM_VERSION #ifdef TVM_LLVM_VERSION
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/codegen.h> #include <tvm/codegen.h>
#include <mutex>
#include "./llvm_common.h" #include "./llvm_common.h"
#include "./codegen_llvm.h" #include "./codegen_llvm.h"
#include "../../runtime/file_util.h" #include "../../runtime/file_util.h"
......
...@@ -417,7 +417,7 @@ TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable) ...@@ -417,7 +417,7 @@ TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<LetStmt>([](const LetStmt *op, CodeGenStackVM* p) { .set_dispatch<LetStmt>([](const LetStmt *op, CodeGenStackVM* p) {
p->Push(op->value); p->Push(op->value);
int64_t vid = p->AllocVarID(op->var.get()); int64_t vid = p->AllocVarID(op->var.get());
p->PushOp(StackVM::STORE_HEAP, vid); p->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
p->Push(op->body); p->Push(op->body);
}) })
.set_dispatch<Ramp>([](const Ramp *op, CodeGenStackVM* p) { .set_dispatch<Ramp>([](const Ramp *op, CodeGenStackVM* p) {
...@@ -445,7 +445,7 @@ TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable) ...@@ -445,7 +445,7 @@ TVM_STATIC_IR_FUNCTOR(CodeGenStackVM, vtable)
.set_dispatch<Let>([](const Let *op, CodeGenStackVM* p) { .set_dispatch<Let>([](const Let *op, CodeGenStackVM* p) {
p->Push(op->value); p->Push(op->value);
int64_t vid = p->AllocVarID(op->var.get()); int64_t vid = p->AllocVarID(op->var.get());
p->PushOp(StackVM::STORE_HEAP, vid); p->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
p->Push(op->body); p->Push(op->body);
}) })
.set_dispatch<Load>([](const Load *op, CodeGenStackVM* p) { .set_dispatch<Load>([](const Load *op, CodeGenStackVM* p) {
......
...@@ -125,7 +125,7 @@ inline bool prove_equal(Expr lhs, Expr rhs) { ...@@ -125,7 +125,7 @@ inline bool prove_equal(Expr lhs, Expr rhs) {
} }
int ScanOpNode::num_outputs() const { int ScanOpNode::num_outputs() const {
return update.size(); return static_cast<int>(update.size());
} }
Array<IterVar> ScanOpNode::root_iter_vars() const { Array<IterVar> ScanOpNode::root_iter_vars() const {
return Array<IterVar>{scan_axis}; return Array<IterVar>{scan_axis};
......
...@@ -103,7 +103,7 @@ LoweredFunc MakeAPI(Stmt body, ...@@ -103,7 +103,7 @@ LoweredFunc MakeAPI(Stmt body,
MakeAssertEQ(v_num_packed_args, num_packed_args, os.str())); MakeAssertEQ(v_num_packed_args, num_packed_args, os.str()));
} }
for (size_t i = 0; i < api_args.size(); ++i) { for (int i = 0; i < static_cast<int>(api_args.size()); ++i) {
Var v_arg = f_arg_decl(i); Var v_arg = f_arg_decl(i);
if (i < static_cast<size_t>(num_packed_args)) { if (i < static_cast<size_t>(num_packed_args)) {
seq_init.emplace_back(LetStmt::make( seq_init.emplace_back(LetStmt::make(
......
...@@ -89,7 +89,7 @@ struct TVMRuntimeEntry { ...@@ -89,7 +89,7 @@ struct TVMRuntimeEntry {
if (val != nullptr) { if (val != nullptr) {
num_par_threads = atoi(val); num_par_threads = atoi(val);
} else { } else {
num_par_threads = std::thread::hardware_concurrency(); num_par_threads = std::thread::hardware_concurrency() / 2;
} }
} }
}; };
...@@ -127,7 +127,7 @@ int TVMModGetFunction(TVMModuleHandle mod, ...@@ -127,7 +127,7 @@ int TVMModGetFunction(TVMModuleHandle mod,
TVMFunctionHandle *func) { TVMFunctionHandle *func) {
API_BEGIN(); API_BEGIN();
PackedFunc pf = static_cast<Module*>(mod)->GetFunction( PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
func_name, query_imports); func_name, query_imports != 0);
if (pf != nullptr) { if (pf != nullptr) {
*func = new PackedFunc(pf); *func = new PackedFunc(pf);
} else { } else {
......
...@@ -39,7 +39,7 @@ class CUDAModuleNode : public runtime::ModuleNode { ...@@ -39,7 +39,7 @@ class CUDAModuleNode : public runtime::ModuleNode {
~CUDAModuleNode() { ~CUDAModuleNode() {
for (size_t i = 0; i < module_.size(); ++i) { for (size_t i = 0; i < module_.size(); ++i) {
if (module_[i] != nullptr) { if (module_[i] != nullptr) {
CUDA_CALL(cudaSetDevice(i)); CUDA_CALL(cudaSetDevice(static_cast<int>(i)));
CUDA_DRIVER_CALL(cuModuleUnload(module_[i])); CUDA_DRIVER_CALL(cuModuleUnload(module_[i]));
} }
} }
......
...@@ -75,11 +75,13 @@ class DSOModuleNode : public ModuleNode { ...@@ -75,11 +75,13 @@ class DSOModuleNode : public ModuleNode {
HMODULE lib_handle_{nullptr}; HMODULE lib_handle_{nullptr};
// Load the library // Load the library
void Load(const std::string& name) { void Load(const std::string& name) {
lib_handle_ = LoadLibrary(name.c_str()); // use wstring version that is needed by LLVM.
std::wstring wname(name.begin(), name.end());
lib_handle_ = LoadLibraryW(wname.c_str());
} }
BackendPackedCFunc GetFuncPtr(const std::string& name) { BackendPackedCFunc GetFuncPtr(const std::string& name) {
return reinterpret_cast<BackendPackedCFunc>( return reinterpret_cast<BackendPackedCFunc>(
GetProcAddress(lib_handle_, name.c_str())); // NOLINT(*) GetProcAddress(lib_handle_, (LPCSTR)name.c_str())); // NOLINT(*)
} }
void* GetGlobalVPtr(const std::string& name) { void* GetGlobalVPtr(const std::string& name) {
return reinterpret_cast<void*>( return reinterpret_cast<void*>(
......
...@@ -119,9 +119,9 @@ TVM_REGISTER_GLOBAL(_module__GetImport) ...@@ -119,9 +119,9 @@ TVM_REGISTER_GLOBAL(_module__GetImport)
imports().at(args[1].operator int()); imports().at(args[1].operator int());
}); });
TVM_REGISTER_GLOBAL(_module__GetTyeKey) TVM_REGISTER_GLOBAL(_module__GetTypeKey)
.set_body([](TVMArgs args, TVMRetValue *ret) { .set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator Module()->type_key(); *ret = std::string(args[0].operator Module()->type_key());
}); });
TVM_REGISTER_GLOBAL(_module__LoadFromFile) TVM_REGISTER_GLOBAL(_module__LoadFromFile)
......
...@@ -389,7 +389,7 @@ void InferRootBound(const Stage& stage, ...@@ -389,7 +389,7 @@ void InferRootBound(const Stage& stage,
bool direct_consume_by_parent = false; bool direct_consume_by_parent = false;
for (int i = 0; i < stage->op->num_outputs(); ++i) { for (int i = 0; i < stage->op->num_outputs(); ++i) {
Tensor t = stage->op.output(i); Tensor t = stage->op.output(i);
tmap.emplace(t, TensorDom(t.ndim())); tmap.emplace(t, TensorDom(static_cast<int>(t.ndim())));
auto it = feed_graph.find(t); auto it = feed_graph.find(t);
if (it != feed_graph.end()) { if (it != feed_graph.end()) {
for (const Operation& op : it->second) { for (const Operation& op : it->second) {
......
...@@ -22,6 +22,9 @@ struct TensorDimKey { ...@@ -22,6 +22,9 @@ struct TensorDimKey {
TensorDimKey(const Tensor& t, int dim) TensorDimKey(const Tensor& t, int dim)
: f(t->op), value_index(t->value_index), dim(dim) { : f(t->op), value_index(t->value_index), dim(dim) {
} }
TensorDimKey(const Tensor& t, size_t dim)
: f(t->op), value_index(t->value_index), dim(static_cast<int>(dim)) {
}
inline bool operator==(const TensorDimKey& other) const { inline bool operator==(const TensorDimKey& other) const {
return f == other.f && return f == other.f &&
value_index == other.value_index && value_index == other.value_index &&
...@@ -183,7 +186,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) { ...@@ -183,7 +186,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
const auto& init = op.as<ScanOpNode>()->init; const auto& init = op.as<ScanOpNode>()->init;
for (size_t i = 0; i < update.size(); ++i) { for (size_t i = 0; i < update.size(); ++i) {
Tensor t = op.output(i); Tensor t = op.output(i);
for (size_t k = 1; k < update[i]->shape.size(); ++k) { for (int k = 1; k < static_cast<int>(update[i]->shape.size()); ++k) {
reach[TensorDimKey(t, k)].emplace_back( reach[TensorDimKey(t, k)].emplace_back(
TensorDimKey(update[i], k)); TensorDimKey(update[i], k));
reach[TensorDimKey(t, k)].emplace_back( reach[TensorDimKey(t, k)].emplace_back(
...@@ -203,7 +206,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) { ...@@ -203,7 +206,7 @@ ReachGraph GetReachGraph(const Array<Operation>& ops) {
if (call != nullptr && call->func.defined()) { if (call != nullptr && call->func.defined()) {
if (!bset.count(call->func.get())) return; if (!bset.count(call->func.get())) return;
for (size_t i = 0; i < call->args.size(); ++i) { for (size_t i = 0; i < call->args.size(); ++i) {
TensorDimKey dkey(call, i); TensorDimKey dkey(call, static_cast<int>(i));
auto fpush = [&dkey, &vmap, &reach](const NodeRef& node) { auto fpush = [&dkey, &vmap, &reach](const NodeRef& node) {
const Variable *v = node.as<Variable>(); const Variable *v = node.as<Variable>();
auto it = vmap.find(v); auto it = vmap.find(v);
...@@ -319,7 +322,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis( ...@@ -319,7 +322,7 @@ Map<IterVar, Expr> ScanFixPointAnalysis(
if (call != nullptr && call->func.defined()) { if (call != nullptr && call->func.defined()) {
for (size_t i = 0; i < call->args.size(); ++i) { for (size_t i = 0; i < call->args.size(); ++i) {
auto it = vmap.find(call->args[i].get()); auto it = vmap.find(call->args[i].get());
TensorDimKey src(call, i); TensorDimKey src(call, static_cast<int>(i));
if (it != vmap.end()) { if (it != vmap.end()) {
f_merge_key(it->second, src); f_merge_key(it->second, src);
} else { } else {
......
...@@ -264,7 +264,7 @@ Schedule::Schedule(Array<Operation> ops) { ...@@ -264,7 +264,7 @@ Schedule::Schedule(Array<Operation> ops) {
} }
for (Operation op : post_order) { for (Operation op : post_order) {
Stage stage(op); Stage stage(op);
stage->is_output = output_set.count(op); stage->is_output = output_set.count(op) != 0;
n->stages.push_back(stage); n->stages.push_back(stage);
n->stage_map.Set(op, stage); n->stage_map.Set(op, stage);
// mark scan updates. // mark scan updates.
......
...@@ -21,8 +21,10 @@ def test_add(): ...@@ -21,8 +21,10 @@ def test_add():
# one line to build the function. # one line to build the function.
def check_device(device, host="stackvm"): def check_device(device, host="stackvm"):
if not tvm.codegen.enabled(host): if not tvm.codegen.enabled(host):
print("skip because %s is not enabled.." % host)
return return
if not tvm.codegen.enabled(device): if not tvm.codegen.enabled(device):
print("skip because %s is not enabled.." % device)
return return
fadd = tvm.build(s, [A, B, C], fadd = tvm.build(s, [A, B, C],
device, host, device, host,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment