Unverified Commit 33b0831c by Tianqi Chen Committed by GitHub

[REFACTOR][CODEGEN] codegen->target, build_module->driver (#4742)

This PR moves the codegen related code into the target folder,
as they are target specific functionalities.

We also adopt the term "compiler driver" in common compiler infra
such as rust, GHC and clang.
As a result, build_module is moved into the driver folder.
parent 992b5b54
......@@ -127,16 +127,17 @@ assign_source_group("Include" ${GROUP_INCLUDE})
file(GLOB_RECURSE COMPILER_SRCS
src/node/*.cc
src/ir/*.cc
src/target/*.cc
src/arith/*.cc
src/top/*.cc
src/api/*.cc
src/autotvm/*.cc
src/tir/*.cc
src/driver/*.cc
src/api/*.cc
)
file(GLOB CODEGEN_SRCS
src/codegen/*.cc
src/target/*.cc
src/target/source/*.cc
)
list(APPEND COMPILER_SRCS ${CODEGEN_SRCS})
......@@ -170,7 +171,7 @@ if(USE_VM_PROFILER)
list(APPEND COMPILER_SRCS ${BACKEND_VM_PROFILER_SRCS})
endif(USE_VM_PROFILER)
file(GLOB DATATYPE_SRCS src/codegen/datatype/*.cc)
file(GLOB DATATYPE_SRCS src/target/datatype/*.cc)
list(APPEND COMPILER_SRCS ${DATATYPE_SRCS})
......@@ -197,7 +198,7 @@ if(USE_RPC)
endif(USE_RPC)
file(GLOB STACKVM_RUNTIME_SRCS src/runtime/stackvm/*.cc)
file(GLOB STACKVM_CODEGEN_SRCS src/codegen/stackvm/*.cc)
file(GLOB STACKVM_CODEGEN_SRCS src/target/stackvm/*.cc)
list(APPEND COMPILER_SRCS ${STACKVM_CODEGEN_SRCS})
if(USE_STACKVM_RUNTIME)
message(STATUS "Build with stackvm support in runtime...")
......
......@@ -31,7 +31,7 @@ if(USE_CUDA)
message(STATUS "Build with CUDA support")
file(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_CUDA_SRCS})
list(APPEND COMPILER_SRCS src/codegen/opt/build_cuda_on.cc)
list(APPEND COMPILER_SRCS src/target/opt/build_cuda_on.cc)
list(APPEND TVM_LINKER_LIBS ${CUDA_NVRTC_LIBRARY})
list(APPEND TVM_RUNTIME_LINKER_LIBS ${CUDA_CUDART_LIBRARY})
......@@ -53,5 +53,5 @@ if(USE_CUDA)
endif(USE_CUBLAS)
else(USE_CUDA)
list(APPEND COMPILER_SRCS src/codegen/opt/build_cuda_off.cc)
list(APPEND COMPILER_SRCS src/target/opt/build_cuda_off.cc)
endif(USE_CUDA)
......@@ -26,7 +26,7 @@ if(NOT USE_LLVM STREQUAL "OFF")
message(STATUS "Set TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION})
# Set flags that are only needed for LLVM target
add_definitions(-DTVM_LLVM_VERSION=${TVM_LLVM_VERSION})
file(GLOB COMPILER_LLVM_SRCS src/codegen/llvm/*.cc)
file(GLOB COMPILER_LLVM_SRCS src/target/llvm/*.cc)
list(APPEND TVM_LINKER_LIBS ${LLVM_LIBS})
list(APPEND COMPILER_SRCS ${COMPILER_LLVM_SRCS})
if(NOT MSVC)
......
......@@ -30,5 +30,5 @@ if(USE_METAL)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${MPS_CONTRIB_LIB})
endif()
else(USE_METAL)
list(APPEND COMPILER_SRCS src/codegen/opt/build_metal_off.cc)
list(APPEND COMPILER_SRCS src/target/opt/build_metal_off.cc)
endif(USE_METAL)
......@@ -33,7 +33,7 @@ if(USE_SDACCEL)
set(USE_OPENCL ON)
endif()
else()
list(APPEND COMPILER_SRCS src/codegen/opt/build_sdaccel_off.cc)
list(APPEND COMPILER_SRCS src/target/opt/build_sdaccel_off.cc)
endif(USE_SDACCEL)
if(USE_AOCL)
......@@ -45,7 +45,7 @@ if(USE_AOCL)
set(USE_OPENCL ON)
endif()
else()
list(APPEND COMPILER_SRCS src/codegen/opt/build_aocl_off.cc)
list(APPEND COMPILER_SRCS src/target/opt/build_aocl_off.cc)
endif(USE_AOCL)
if(USE_OPENCL)
......@@ -55,5 +55,5 @@ if(USE_OPENCL)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenCL_LIBRARIES})
list(APPEND RUNTIME_SRCS ${RUNTIME_OPENCL_SRCS})
else()
list(APPEND COMPILER_SRCS src/codegen/opt/build_opencl_off.cc)
list(APPEND COMPILER_SRCS src/target/opt/build_opencl_off.cc)
endif(USE_OPENCL)
......@@ -31,5 +31,5 @@ if(USE_OPENGL)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenGL_LIBRARIES} glfw)
list(APPEND RUNTIME_SRCS ${RUNTIME_OPENGL_SRCS})
else(USE_OPENGL)
list(APPEND COMPILER_SRCS src/codegen/opt/build_opengl_off.cc)
list(APPEND COMPILER_SRCS src/target/opt/build_opengl_off.cc)
endif(USE_OPENGL)
......@@ -49,5 +49,5 @@ if(USE_ROCM)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${ROCM_ROCBLAS_LIBRARY})
endif(USE_ROCBLAS)
else(USE_ROCM)
list(APPEND COMPILER_SRCS src/codegen/opt/build_rocm_off.cc)
list(APPEND COMPILER_SRCS src/target/opt/build_rocm_off.cc)
endif(USE_ROCM)
......@@ -38,7 +38,7 @@ if(USE_VULKAN)
endif()
message(STATUS "Build with Vulkan support")
file(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/vulkan.cc)
file(GLOB COMPILER_VULKAN_SRCS src/codegen/spirv/*.cc)
file(GLOB COMPILER_VULKAN_SRCS src/target/spirv/*.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS})
list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS})
list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY})
......
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/driver/driver.h
* \brief Compiler driver utilities.
*
* This module provides end-to-end utils to drive the compilation process.
* We adopt the term "compiler driver" in common compiler infrastructures.
* Note that a compiler driver is different from "runtime drivers".
* Most of runtime related code are defined in the runtime folder instead.
*/
#ifndef TVM_DRIVER_DRIVER_H_
#define TVM_DRIVER_DRIVER_H_
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>
#include <tvm/support/with.h>
#include <tvm/top/schedule_pass.h>
#include <tvm/tir/lowered_func.h>
#include <string>
#include <vector>
#include <utility>
#include <unordered_map>
#include <unordered_set>
namespace tvm {
/*!
* \brief Build a LoweredFunc given a schedule, args and binds
* \param sch The schedule to lower.
* \param args The arguments to the function.
* \param name The name of the lowered function.
* \param binds Buffer assignments.
* \param config The build configuration.
* \return The lowered function.
*/
TVM_DLL Array<tir::LoweredFunc> lower(
top::Schedule sch,
const Array<top::Tensor>& args,
const std::string& name,
const std::unordered_map<top::Tensor, tir::Buffer>& binds,
const BuildConfig& config);
/*!
* \brief Split host/device function and running necessary pass before build
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The Array<Array<LoweredFunc>> with 2 elements. First is host function Array,
second is device function array
*/
TVM_DLL Array<Array<tir::LoweredFunc> > split_dev_host_funcs(
const Array<tir::LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);
/*!
* \brief Build a device and host module for a specific target from an array of lowered functions.
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return The built module.
*/
TVM_DLL runtime::Module build(const Array<tir::LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);
/*!
* \brief Build a device and host module for a specific target from a map
* contains target to a list of lowered functions pairs. This function is used
* for heterogeneous build.
* \param input The map contains target to a list of lowered functions pairs.
* \param target_host The target for building host code. To use the default,
* pass Target().
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<Target, Array<tir::LoweredFunc>>& input,
const Target& target_host,
const BuildConfig& config);
/*!
* \brief Build a device and host module for a specific target from a map
* contains target to a list of lowered functions pairs. This function is used
* for heterogeneous build.
* \param input The map contains target string to a list of lowered functions
* pairs.
* \param target_host The target for building host code. To use the default,
* pass Target().
* \param config The build configuration.
* \return The built module that contains code for different processors.
*/
TVM_DLL runtime::Module build(const Map<std::string, Array<tir::LoweredFunc>>& input,
const Target& target_host,
const BuildConfig& config);
} // namespace tvm
#endif // TVM_DRIVER_DRIVER_H_
......@@ -34,12 +34,13 @@
#ifndef TVM_RELAY_INTERPRETER_H_
#define TVM_RELAY_INTERPRETER_H_
#include <tvm/build_module.h>
#include <tvm/ir/module.h>
#include <tvm/relay/expr.h>
#include <tvm/runtime/object.h>
#include <tvm/runtime/container.h>
#include <tvm/runtime/vm.h>
#include <tvm/target/target.h>
namespace tvm {
namespace relay {
......
......@@ -26,9 +26,9 @@
#include <tvm/top/tensor.h>
#include <tvm/top/schedule.h>
#include <tvm/build_module.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
#include <tvm/target/target.h>
#include <tvm/tir/data_layout.h>
#include <string>
......
......@@ -18,20 +18,22 @@
*/
/*!
* \file tvm/codegen.h
* \brief Collection of Lowlevel IR pass to codegen.
* \file tvm/target/codegen.h
* \brief Translates IRModule to runtime::Module.
*/
#ifndef TVM_CODEGEN_H_
#define TVM_CODEGEN_H_
#ifndef TVM_TARGET_CODEGEN_H_
#define TVM_TARGET_CODEGEN_H_
#include <tvm/runtime/packed_func.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/target/target.h>
#include <string>
namespace tvm {
/*! \brief namespace for lowlevel IR pass and codegen */
/*! \brief namespace for target translation and codegen. */
namespace codegen {
// use packed function from runtime.
using runtime::PackedFunc;
......@@ -76,5 +78,4 @@ runtime::Module PackImportsToLLVM(const runtime::Module& m,
const std::string& target_triple);
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_H_
#endif // TVM_TARGET_CODEGEN_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/target/generic_func.h
* \brief Generic function that can be specialzied on a per target basis.
*/
#ifndef TVM_TARGET_GENERIC_FUNC_H_
#define TVM_TARGET_GENERIC_FUNC_H_
#include <tvm/support/with.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/target/target.h>
#include <vector>
#include <string>
#include <utility>
#include <unordered_map>
namespace tvm {
class GenericFuncNode;
/*!
* \brief Generic function that can be specialized on a per-target basis.
*/
class GenericFunc : public ObjectRef {
public:
GenericFunc() {}
explicit GenericFunc(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Set the default function implementaiton.
* \param value The default function
* \param allow_override If true, this call may override a previously registered function. If
* false, an error will be logged if the call would override a previously registered function.
* \return reference to self.
*/
TVM_DLL GenericFunc& set_default(const runtime::PackedFunc value,
bool allow_override = false);
/*!
* \brief Register a specialized function
* \param tags The tags for this specialization
* \param value The specialized function
* \param allow_override If true, this call may override previously registered tags. If false,
* an error will be logged if the call would override previously registered tags.
* \return reference to self.
*/
TVM_DLL GenericFunc& register_func(const std::vector<std::string>& tags,
const runtime::PackedFunc value,
bool allow_override = false);
/*!
* \brief Call generic function by directly passing in unpacked format.
* \param args Arguments to be passed.
* \tparam Args arguments to be passed.
*
* \code
* // Example code on how to call generic function
* void CallGeneirc(GenericFunc f) {
* // call like normal functions by pass in arguments
* // return value is automatically converted back
* int rvalue = f(1, 2.0);
* }
* \endcode
*/
template<typename... Args>
inline runtime::TVMRetValue operator()(Args&& ...args) const;
/*!
* \brief Invoke the relevant function for the current target context, set by set_target_context.
* Arguments are passed in packed format.
* \param args The arguments to pass to the function.
* \param ret The return value
*/
TVM_DLL void CallPacked(runtime::TVMArgs args,
runtime::TVMRetValue* ret) const;
/*!
* \brief Find or register the GenericFunc instance corresponding to the give name
* \param name The name of the registered GenericFunc
* \return The GenericFunc instance
*/
TVM_DLL static GenericFunc Get(const std::string& name);
/*!
* \brief Add a GenericFunc instance to the registry
* \param func The GenericFunc instance
* \param name The name of the registered GenericFunc
*/
TVM_DLL static void RegisterGenericFunc(GenericFunc func, const std::string& name);
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline GenericFuncNode* operator->();
// declare container type
using ContainerType = GenericFuncNode;
// Internal class.
struct Manager;
private:
friend struct Manager;
};
template<typename... Args>
inline runtime::TVMRetValue GenericFunc::operator()(Args&& ...args) const {
const int kNumArgs = sizeof...(Args);
const int kArraySize = kNumArgs > 0 ? kNumArgs : 1;
TVMValue values[kArraySize];
int type_codes[kArraySize];
runtime::detail::for_each(runtime::TVMArgsSetter(values, type_codes),
std::forward<Args>(args)...);
runtime::TVMRetValue rv;
CallPacked(runtime::TVMArgs(values, type_codes, kNumArgs), &rv);
return rv;
}
/*!
* \brief Represents a generic function that can be specialized on a per-target basis.
*/
class GenericFuncNode : public Object {
public:
/*! \brief name of the function */
std::string name_;
/* \brief the generic builder */
runtime::PackedFunc generic_func_;
/* \brief map from keys to registered functions */
std::unordered_map<std::string, runtime::PackedFunc> dispatch_dict_;
void VisitAttrs(AttrVisitor* v) {}
static constexpr const char* _type_key = "GenericFunc";
TVM_DECLARE_FINAL_OBJECT_INFO(GenericFuncNode, Object);
};
inline GenericFuncNode* GenericFunc::operator->() {
return static_cast<GenericFuncNode*>(get_mutable());
}
#define TVM_GENERIC_FUNC_REG_VAR_DEF \
static TVM_ATTRIBUTE_UNUSED ::tvm::GenericFunc& __mk_ ## TVM
/*!
* \def TVM_REGISTER_GENERIC_FUNC
* \brief Register a new generic function, or set a device-specific variant
* of the corresponding function.
*
* \param name The name of the function
*/
#define TVM_REGISTER_GENERIC_FUNC(name) \
TVM_STR_CONCAT(TVM_GENERIC_FUNC_REG_VAR_DEF, __COUNTER__) = \
::tvm::GenericFunc::Get(#name)
} // namespace tvm
#endif // TVM_TARGET_GENERIC_FUNC_H_
......@@ -31,6 +31,7 @@
#include <string>
#include <vector>
#include <unordered_set>
#include <utility>
namespace tvm {
/*!
......@@ -177,5 +178,131 @@ TVM_DLL Target stackvm(const std::vector<std::string>& options =
TVM_DLL Target ext_dev(const std::vector<std::string>& options =
std::vector<std::string>());
} // namespace target
/*!
* \brief Container for build configuration options
*/
class BuildConfigNode : public Object {
public:
/*!
* \brief The data alignment to use when constructing buffers. If this is set to
* -1, then TVM's internal default will be used
*/
int data_alignment = -1;
/*!
* \brief The offset factor to use when constructing buffers. If this is set to
* 0, then the offset field is not used.
*/
int offset_factor = 0;
/*!
* \brief Splitting factor for loop splitting. If this is set to zero, no splitting will be
* done. Otherwise, a split will be done with this factor and the inner loop will be unrolled.
*/
int double_buffer_split_loop = 1;
/*! \brief Threshold of number of steps in the loop to be automatically unrolled */
int auto_unroll_max_step = 0;
/*! \brief The maximum nested level of loops that can be automatically unrolled */
int auto_unroll_max_depth = 8;
/*! \brief The maximum extent of loop that will be unrolled */
int auto_unroll_max_extent = 0;
/*!
* \brief Whether to explicitly unroll the loop. If set to false, the unroll hint will
* be passed to the CodeGen phase. Set to true if CodeGen supports unroll pragma.
*/
bool unroll_explicit = true;
/*! \brief Set to true if buffer arguments do not overlap. This enables more optimization. */
bool restricted_func = true;
/*! \brief Whether to detect global barrier */
bool detect_global_barrier = false;
/*! \brief Whether to partition const loop */
bool partition_const_loop = false;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false;
/*! \brief Whether to instrument loads and stores with check for out of the bounds. */
bool instrument_bound_checkers = false;
/*! \brief Whether to disable select rewriting. */
bool disable_select_rewriting = false;
/*! \brief Whether to disable loop vectorization. */
bool disable_vectorize = false;
/*! \brief Whether to disable assert stmt generation. */
bool disable_assert = false;
void VisitAttrs(AttrVisitor* v) {
v->Visit("data_alignment", &data_alignment);
v->Visit("offset_factor", &offset_factor);
v->Visit("double_buffer_split_loop", &double_buffer_split_loop);
v->Visit("auto_unroll_max_step", &auto_unroll_max_step);
v->Visit("auto_unroll_max_depth", &auto_unroll_max_depth);
v->Visit("auto_unroll_max_extent", &auto_unroll_max_extent);
v->Visit("unroll_explicit", &unroll_explicit);
v->Visit("restricted_func", &restricted_func);
v->Visit("detect_global_barrier", &detect_global_barrier);
v->Visit("partition_const_loop", &partition_const_loop);
v->Visit("dump_pass_ir", &dump_pass_ir);
v->Visit("instrument_bound_checkers", &instrument_bound_checkers);
v->Visit("disable_select_rewriting", &disable_select_rewriting);
v->Visit("disable_vectorize", &disable_vectorize);
v->Visit("disable_assert", &disable_assert);
}
static constexpr const char* _type_key = "BuildConfig";
TVM_DECLARE_FINAL_OBJECT_INFO(BuildConfigNode, Object);
};
/*!
* \brief Build configuration for compilations.
*/
class BuildConfig : public ::tvm::ObjectRef {
public:
BuildConfig() {}
explicit BuildConfig(ObjectPtr<Object> n) : ObjectRef(n) {}
const BuildConfigNode* operator->() const {
return static_cast<const BuildConfigNode*>(get());
}
BuildConfigNode* operator->() {
return static_cast<BuildConfigNode*>(get_mutable());
}
/*!
* \brief Construct a BuildConfig containing a empty build config node.
* \return The new BuildConfig
*/
TVM_DLL static BuildConfig Create();
/*!
* \brief Get the current BuildConfig context from thread local storage, or a default
* configuration if a BuildConfig scope has not been entered.
* \return The configuration that is the current context.
*/
TVM_DLL static BuildConfig Current();
using ContainerType = BuildConfigNode;
class Internal;
private:
// Enable with syntax.
friend class With<BuildConfig>;
/*!
* \brief Push a new BuildConfig context onto the thread local stack.
*/
TVM_DLL void EnterWithScope();
/*!
* \brief Pop a build config off the thread local context stack,
* restoring the previous configuration as the current context.
*/
TVM_DLL void ExitWithScope();
};
} // namespace tvm
#endif // TVM_TARGET_TARGET_H_
......@@ -29,7 +29,6 @@ There can be internal header files within each module that sit in src.
- arith: Arithmetic expression and set simplification.
- top: tensor operation DSL for compute and schedule.
- relay: Relay IR, high-level optimization.
- codegen: The code generator.
- autotvm: The auto-tuning module.
- contrib: Contrib extension libraries.
- api: API function registration.
......@@ -23,7 +23,7 @@
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/expr.h>
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/runtime/registry.h>
......
......@@ -29,7 +29,7 @@
#include <tvm/top/schedule.h>
#include <tvm/runtime/registry.h>
#include <tvm/build_module.h>
#include <tvm/driver/driver.h>
#include <tvm/tir/data_layout.h>
......
......@@ -26,7 +26,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/top/schedule.h>
#include <map>
......
......@@ -19,13 +19,13 @@
/*!
* Compile executable modules.
* \file build_module.cc
* \file driver.cc
*/
#include <dmlc/thread_local.h>
#include <tvm/build_module.h>
#include <tvm/driver/driver.h>
#include <tvm/top/operation.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <tvm/runtime/registry.h>
#include <algorithm>
......@@ -39,8 +39,6 @@ using runtime::TVMRetValue;
using runtime::PackedFunc;
using tir::LoweredFunc;
TVM_REGISTER_NODE_TYPE(GenericFuncNode);
bool LLVMEnabled() {
const runtime::PackedFunc* pf = runtime::Registry::Get("codegen.build_llvm");
return pf != nullptr;
......@@ -343,280 +341,4 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
return build(inputs, target_host, config);
}
BuildConfig BuildConfig::Create() {
return BuildConfig(make_object<BuildConfigNode>());
}
/*! \brief Entry to hold the BuildConfig context stack. */
struct TVMBuildConfigThreadLocalEntry {
/*! \brief The default build config if the stack is empty */
BuildConfig default_config;
/*! \brief The current build config context */
std::stack<BuildConfig> context_stack;
TVMBuildConfigThreadLocalEntry() :
default_config(BuildConfig::Create()) {
}
};
/*! \brief Thread local store to hold the BuildConfig context stack. */
typedef dmlc::ThreadLocalStore<TVMBuildConfigThreadLocalEntry> TVMBuildConfigThreadLocalStore;
void BuildConfig::EnterWithScope() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
entry->context_stack.push(*this);
}
void BuildConfig::ExitWithScope() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
CHECK(!entry->context_stack.empty());
CHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop();
}
tvm::BuildConfig BuildConfig::Current() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
}
return entry->default_config;
}
TVM_REGISTER_NODE_TYPE(BuildConfigNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BuildConfigNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const BuildConfigNode*>(node.get());
p->stream << "build_config(";
p->stream << "data_alignment=" << op->data_alignment << ", ";
p->stream << "offset_factor=" << op->offset_factor << ", ";
p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", ";
p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", ";
p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", ";
p->stream << "auto_unroll_max_extent=" << op->auto_unroll_max_extent << ", ";
p->stream << "unroll_explicit=" << op->unroll_explicit << ", ";
p->stream << "restricted_func=" << op->restricted_func << ", ";
p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", ";
p->stream << "partition_const_loop=" << op->partition_const_loop << ", ";
p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
p->stream << "disable_vectorize=" << op->disable_vectorize;
p->stream << "disable_assert=" << op->disable_assert;
p->stream << ")";
});
struct GenericFunc::Manager {
std::unordered_map<std::string, GenericFunc> fmap;
// mutex
std::mutex mutex;
Manager() {
}
static Manager* Global() {
static Manager inst;
return &inst;
}
};
GenericFunc GenericFunc::Get(const std::string& name) {
Manager* m = Manager::Global();
std::lock_guard<std::mutex>(m->mutex);
auto it = m->fmap.find(name);
if (it == m->fmap.end()) {
auto f = make_object<GenericFuncNode>();
f->name_ = name;
auto gf = GenericFunc(f);
m->fmap[name] = gf;
return gf;
} else {
return it->second;
}
}
void GenericFunc::RegisterGenericFunc(GenericFunc func, const std::string& name) {
Manager* m = Manager::Global();
std::lock_guard<std::mutex>(m->mutex);
auto it = m->fmap.find(name);
CHECK(it == m->fmap.end()) << "GenericFunc already registered " << name;
func->name_ = name;
m->fmap[name] = func;
}
GenericFunc& GenericFunc::set_default(const PackedFunc value,
bool allow_override) {
auto node = static_cast<GenericFuncNode*>(operator->());
if (!allow_override) {
CHECK(node->generic_func_ == nullptr)
<< "Generic function already registered for " << node->name_;
}
node->generic_func_ = value;
return *this;
}
GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
const PackedFunc value,
bool allow_override) {
for (auto &t : tags) {
if (!allow_override) {
auto iter = (*this)->dispatch_dict_.find(t);
CHECK(iter == (*this)->dispatch_dict_.end())
<< "Tag " << t << " already registered for schedule factory " << (*this)->name_;
}
(*this)->dispatch_dict_[t] = value;
}
return *this;
}
void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
auto node = static_cast<const GenericFuncNode*>(get());
auto target = Target::Current(true);
PackedFunc func;
if (target.defined()) {
for (auto &k : target->keys()) {
auto iter = node->dispatch_dict_.find(k);
if (iter != node->dispatch_dict_.end()) {
func = iter->second;
break;
}
}
}
if (func == nullptr) {
CHECK(node->generic_func_ != nullptr) << "No generic function registered for " << node->name_;
func = node->generic_func_;
}
func.CallPacked(args, ret);
}
TVM_REGISTER_GLOBAL("_GetCurrentBuildConfig")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BuildConfig::Current();
});
class BuildConfig::Internal {
public:
static void EnterScope(BuildConfig target) {
target.EnterWithScope();
}
static void ExitScope(BuildConfig target) {
target.ExitWithScope();
}
};
TVM_REGISTER_GLOBAL("_EnterBuildConfigScope")
.set_body_typed(BuildConfig::Internal::EnterScope);
TVM_REGISTER_GLOBAL("_ExitBuildConfigScope")
.set_body_typed(BuildConfig::Internal::ExitScope);
TVM_REGISTER_GLOBAL("_BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig cfg = args[0];
std::vector< std::pair<int, PackedFunc> > add_lower_pass;
CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) {
add_lower_pass.push_back(std::make_pair(
args[i].operator int(),
args[i + 1].operator tvm::runtime::PackedFunc()));
}
cfg->add_lower_pass = add_lower_pass;
});
TVM_REGISTER_GLOBAL("_BuildConfigGetAddLowerPassInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
// Return one of the following:
// * Size of add_lower_pass if num_args == 1
// * Phase index of pass if args are (config, index, true)
// * Function of pass if args are (config, index, false)
BuildConfig cfg = args[0];
if (args.num_args == 1) {
*ret = static_cast<int64_t>(cfg->add_lower_pass.size());
} else {
int index = args[1];
bool get_phase = args[2];
auto item = cfg->add_lower_pass[index];
if (get_phase) {
*ret = item.first;
} else {
*ret = item.second;
}
}
});
TVM_REGISTER_GLOBAL("_GenericFuncCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GenericFunc(make_object<GenericFuncNode>());
});
TVM_REGISTER_GLOBAL("_GenericFuncGetGlobal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::string func_name = args[0];
*ret = GenericFunc::Get(func_name);
});
TVM_REGISTER_GLOBAL("_GenericFuncSetDefault")
.set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
bool allow_override = args[2];
generic_func
.set_default(*func, allow_override);
});
TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc")
.set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
Array<PrimExpr> tags = args[2];
bool allow_override = args[3];
std::vector<std::string> tags_vector;
for (auto& tag : tags) {
tags_vector.push_back(tag.as<tvm::tir::StringImmNode>()->value);
}
generic_func
.register_func(tags_vector, *func, allow_override);
});
TVM_REGISTER_GLOBAL("_GenericFuncCallFunc")
.set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0];
TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1);
generic_func
.CallPacked(func_args, ret);
});
TVM_REGISTER_GLOBAL("_GetCurrentTarget")
.set_body([](TVMArgs args, TVMRetValue* ret) {
bool allow_not_defined = args[0];
*ret = Target::Current(allow_not_defined);
});
class Target::Internal {
public:
static void EnterScope(Target target) {
target.EnterWithScope();
}
static void ExitScope(Target target) {
target.ExitWithScope();
}
};
TVM_REGISTER_GLOBAL("_EnterTargetScope")
.set_body_typed(Target::Internal::EnterScope);
TVM_REGISTER_GLOBAL("_ExitTargetScope")
.set_body_typed(Target::Internal::ExitScope);
} // namespace tvm
......@@ -22,7 +22,7 @@
* \brief Code generation for TVM's graph runtime.
*/
#include <tvm/relay/analysis.h>
#include <tvm/build_module.h>
#include <tvm/driver/driver.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/expr.h>
......
......@@ -25,6 +25,7 @@
#include <tvm/top/schedule.h>
#include <tvm/top/operation.h>
#include <tvm/top/schedule_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
......@@ -32,6 +33,8 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/driver/driver.h>
#include <topi/tags.h>
#include <utility>
#include <limits>
......
......@@ -30,6 +30,8 @@
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/debug.h>
#include <tvm/relay/feature.h>
#include <tvm/driver/driver.h>
#include "compile_engine.h"
namespace tvm {
......
......@@ -27,8 +27,8 @@
#include <dmlc/json.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/build_module.h>
#include <tvm/codegen.h>
#include <tvm/driver/driver.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/top/operation.h>
......
......@@ -31,6 +31,8 @@
#include <tvm/relay/transform.h>
#include <tvm/runtime/vm.h>
#include <tvm/relay/attrs/memory.h>
#include <tvm/driver/driver.h>
#include <iostream>
#include <memory>
#include <string>
......
......@@ -25,7 +25,6 @@
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/build_module.h>
#include <vector>
#include "../op_common.h"
......
......@@ -21,13 +21,14 @@
* Common build utilities
* \file build_common.h
*/
#ifndef TVM_CODEGEN_BUILD_COMMON_H_
#define TVM_CODEGEN_BUILD_COMMON_H_
#ifndef TVM_TARGET_BUILD_COMMON_H_
#define TVM_TARGET_BUILD_COMMON_H_
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/lowered_func.h>
#include <unordered_map>
#include <string>
#include "../runtime/meta_data.h"
......@@ -36,9 +37,9 @@ namespace tvm {
namespace codegen {
// Extract function information from device function.
inline std::unordered_map<std::string, runtime::FunctionInfo>
ExtractFuncInfo(const Array<LoweredFunc>& funcs) {
ExtractFuncInfo(const Array<tir::LoweredFunc>& funcs) {
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (LoweredFunc f : funcs) {
for (tir::LoweredFunc f : funcs) {
runtime::FunctionInfo info;
for (size_t i = 0; i < f->args.size(); ++i) {
info.arg_types.push_back(f->args[i].dtype());
......@@ -52,4 +53,4 @@ ExtractFuncInfo(const Array<LoweredFunc>& funcs) {
}
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_BUILD_COMMON_H_
#endif // TVM_TARGET_BUILD_COMMON_H_
......@@ -21,12 +21,12 @@
* \file codegen.cc
* \brief Common utilities to generated C style code.
*/
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <tvm/target/target.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/build_module.h>
#include <dmlc/memory_io.h>
#include <sstream>
#include <vector>
......
......@@ -17,8 +17,8 @@
* under the License.
*/
#ifndef TVM_CODEGEN_DATATYPE_REGISTRY_H_
#define TVM_CODEGEN_DATATYPE_REGISTRY_H_
#ifndef TVM_TARGET_DATATYPE_REGISTRY_H_
#define TVM_TARGET_DATATYPE_REGISTRY_H_
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
......@@ -159,4 +159,4 @@ DEFINE_GET_LOWER_FUNC_(GE)
} // namespace datatype
} // namespace tvm
#endif // TVM_CODEGEN_DATATYPE_REGISTRY_H_
#endif // TVM_TARGET_DATATYPE_REGISTRY_H_
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file src/target/generic_func.cc
*/
#include <dmlc/thread_local.h>
#include <tvm/runtime/registry.h>
#include <tvm/node/node.h>
#include <tvm/node/printer.h>
#include <tvm/target/target.h>
#include <tvm/target/generic_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/expr.h>
#include <algorithm>
#include <mutex>
#include <stack>
namespace tvm {
TVM_REGISTER_NODE_TYPE(GenericFuncNode);
struct GenericFunc::Manager {
std::unordered_map<std::string, GenericFunc> fmap;
// mutex
std::mutex mutex;
Manager() {
}
static Manager* Global() {
static Manager inst;
return &inst;
}
};
GenericFunc GenericFunc::Get(const std::string& name) {
Manager* m = Manager::Global();
std::lock_guard<std::mutex>(m->mutex);
auto it = m->fmap.find(name);
if (it == m->fmap.end()) {
auto f = make_object<GenericFuncNode>();
f->name_ = name;
auto gf = GenericFunc(f);
m->fmap[name] = gf;
return gf;
} else {
return it->second;
}
}
void GenericFunc::RegisterGenericFunc(GenericFunc func, const std::string& name) {
Manager* m = Manager::Global();
std::lock_guard<std::mutex>(m->mutex);
auto it = m->fmap.find(name);
CHECK(it == m->fmap.end()) << "GenericFunc already registered " << name;
func->name_ = name;
m->fmap[name] = func;
}
GenericFunc& GenericFunc::set_default(const PackedFunc value,
bool allow_override) {
auto node = static_cast<GenericFuncNode*>(operator->());
if (!allow_override) {
CHECK(node->generic_func_ == nullptr)
<< "Generic function already registered for " << node->name_;
}
node->generic_func_ = value;
return *this;
}
GenericFunc& GenericFunc::register_func(const std::vector<std::string>& tags,
const PackedFunc value,
bool allow_override) {
for (auto &t : tags) {
if (!allow_override) {
auto iter = (*this)->dispatch_dict_.find(t);
CHECK(iter == (*this)->dispatch_dict_.end())
<< "Tag " << t << " already registered for schedule factory " << (*this)->name_;
}
(*this)->dispatch_dict_[t] = value;
}
return *this;
}
void GenericFunc::CallPacked(TVMArgs args, TVMRetValue* ret) const {
auto node = static_cast<const GenericFuncNode*>(get());
auto target = Target::Current(true);
PackedFunc func;
if (target.defined()) {
for (auto &k : target->keys()) {
auto iter = node->dispatch_dict_.find(k);
if (iter != node->dispatch_dict_.end()) {
func = iter->second;
break;
}
}
}
if (func == nullptr) {
CHECK(node->generic_func_ != nullptr) << "No generic function registered for " << node->name_;
func = node->generic_func_;
}
func.CallPacked(args, ret);
}
TVM_REGISTER_GLOBAL("_GenericFuncCreate")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = GenericFunc(make_object<GenericFuncNode>());
});
TVM_REGISTER_GLOBAL("_GenericFuncGetGlobal")
.set_body([](TVMArgs args, TVMRetValue* ret) {
std::string func_name = args[0];
*ret = GenericFunc::Get(func_name);
});
TVM_REGISTER_GLOBAL("_GenericFuncSetDefault")
.set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
bool allow_override = args[2];
generic_func
.set_default(*func, allow_override);
});
TVM_REGISTER_GLOBAL("_GenericFuncRegisterFunc")
.set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0];
// Intentionally copy and not de-allocate it, to avoid free pyobject during shutdown
PackedFunc* func = new PackedFunc(args[1].operator PackedFunc());
Array<PrimExpr> tags = args[2];
bool allow_override = args[3];
std::vector<std::string> tags_vector;
for (auto& tag : tags) {
tags_vector.push_back(tag.as<tvm::tir::StringImmNode>()->value);
}
generic_func
.register_func(tags_vector, *func, allow_override);
});
TVM_REGISTER_GLOBAL("_GenericFuncCallFunc")
.set_body([](TVMArgs args, TVMRetValue* ret) {
GenericFunc generic_func = args[0];
TVMArgs func_args(&args.values[1], &args.type_codes[1], args.num_args - 1);
generic_func
.CallPacked(func_args, ret);
});
} // namespace tvm
......@@ -21,8 +21,8 @@
* \file intrin_rule.h
* \brief Utility to generate intrinsic rules
*/
#ifndef TVM_CODEGEN_INTRIN_RULE_H_
#define TVM_CODEGEN_INTRIN_RULE_H_
#ifndef TVM_TARGET_INTRIN_RULE_H_
#define TVM_TARGET_INTRIN_RULE_H_
#include <tvm/tir/expr.h>
#include <tvm/tir/expr.h>
......@@ -72,4 +72,4 @@ inline void DispatchExtern(const TVMArgs& args, TVMRetValue* rv) {
} // namespace intrin
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_INTRIN_RULE_H_
#endif // TVM_TARGET_INTRIN_RULE_H_
......@@ -28,7 +28,6 @@
#include <tvm/runtime/registry.h>
#include "codegen_llvm.h"
#include "../build_common.h"
#include "../codegen_source_base.h"
#include "../../runtime/rocm/rocm_module.h"
namespace tvm {
......
......@@ -21,8 +21,8 @@
* \file codegen_blob.h
* \brief Code Generation of blob data
*/
#ifndef TVM_CODEGEN_LLVM_CODEGEN_BLOB_H_
#define TVM_CODEGEN_LLVM_CODEGEN_BLOB_H_
#ifndef TVM_TARGET_LLVM_CODEGEN_BLOB_H_
#define TVM_TARGET_LLVM_CODEGEN_BLOB_H_
#ifdef TVM_LLVM_VERSION
#include <utility>
#include <memory>
......@@ -48,4 +48,4 @@ std::pair<std::unique_ptr<llvm::Module>,
} // namespace codegen
} // namespace tvm
#endif // LLVM_VERSION
#endif // TVM_CODEGEN_LLVM_CODEGEN_BLOB_H_
#endif // TVM_TARGET_LLVM_CODEGEN_BLOB_H_
......@@ -21,8 +21,8 @@
* \file codegen_llvm_cpu.h
* \brief Common base class for generating into LLVM IR on CPU host.
*/
#ifndef TVM_CODEGEN_LLVM_CODEGEN_CPU_H_
#define TVM_CODEGEN_LLVM_CODEGEN_CPU_H_
#ifndef TVM_TARGET_LLVM_CODEGEN_CPU_H_
#define TVM_TARGET_LLVM_CODEGEN_CPU_H_
#include <utility>
#include <vector>
......@@ -153,4 +153,4 @@ class CodeGenCPU : public CodeGenLLVM {
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_LLVM_CODEGEN_CPU_H_
#endif // TVM_TARGET_LLVM_CODEGEN_CPU_H_
......@@ -21,8 +21,8 @@
* \file codegen_llvm.h
* \brief Common base class for generating into LLVM IR
*/
#ifndef TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#define TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#ifndef TVM_TARGET_LLVM_CODEGEN_LLVM_H_
#define TVM_TARGET_LLVM_CODEGEN_LLVM_H_
#ifdef TVM_LLVM_VERSION
#include <tvm/arith/analyzer.h>
......@@ -30,7 +30,7 @@
#include <tvm/tir/stmt.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <memory>
#include <utility>
#include <vector>
......@@ -311,4 +311,4 @@ class CodeGenLLVM :
} // namespace codegen
} // namespace tvm
#endif // LLVM_VERSION
#endif // TVM_CODEGEN_LLVM_CODEGEN_LLVM_H_
#endif // TVM_TARGET_LLVM_CODEGEN_LLVM_H_
......@@ -21,14 +21,14 @@
* \file intrin_rule_llvm.h
* \brief Common utilities for llvm intrinsics.
*/
#ifndef TVM_CODEGEN_LLVM_INTRIN_RULE_LLVM_H_
#define TVM_CODEGEN_LLVM_INTRIN_RULE_LLVM_H_
#ifndef TVM_TARGET_LLVM_INTRIN_RULE_LLVM_H_
#define TVM_TARGET_LLVM_INTRIN_RULE_LLVM_H_
#ifdef TVM_LLVM_VERSION
#include <tvm/tir/expr.h>
#include <tvm/runtime/registry.h>
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <string>
#include "llvm_common.h"
......@@ -72,4 +72,4 @@ inline void DispatchLLVMIntrin(const TVMArgs& targs, TVMRetValue* rv) {
} // namespace tvm
#endif // LLVM_VERSION
#endif // TVM_CODEGEN_LLVM_INTRIN_RULE_LLVM_H_
#endif // TVM_TARGET_LLVM_INTRIN_RULE_LLVM_H_
......@@ -21,8 +21,8 @@
* \file llvm_common.h
* \brief Common utilities for llvm initialization.
*/
#ifndef TVM_CODEGEN_LLVM_LLVM_COMMON_H_
#define TVM_CODEGEN_LLVM_LLVM_COMMON_H_
#ifndef TVM_TARGET_LLVM_LLVM_COMMON_H_
#define TVM_TARGET_LLVM_LLVM_COMMON_H_
#ifdef TVM_LLVM_VERSION
#include <llvm/ExecutionEngine/MCJIT.h>
......@@ -114,4 +114,4 @@ GetLLVMTargetMachine(const std::string& target_str, bool allow_null = false);
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
#endif // TVM_CODEGEN_LLVM_LLVM_COMMON_H_
#endif // TVM_TARGET_LLVM_LLVM_COMMON_H_
......@@ -25,7 +25,7 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <mutex>
#include "llvm_common.h"
#include "codegen_llvm.h"
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -20,7 +20,7 @@
/*!
* Optional module when build aocl is switched to off
*/
#include "../codegen_source_base.h"
#include "../source/codegen_source_base.h"
#include "../../runtime/opencl/opencl_module.h"
namespace tvm {
......
......@@ -31,8 +31,8 @@
#include <nvrtc.h>
#include <cstdlib>
#include "../codegen_cuda.h"
#include "../build_common.h"
#include "../source/codegen_cuda.h"
#include "../../runtime/cuda/cuda_common.h"
#include "../../runtime/cuda/cuda_module.h"
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -20,7 +20,7 @@
/*!
* Optional module when build metal is switched to off
*/
#include "../codegen_source_base.h"
#include "../source/codegen_source_base.h"
#include "../../runtime/metal/metal_module.h"
namespace tvm {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -20,7 +20,7 @@
/*!
* Optional module when build opencl is switched to off
*/
#include "../codegen_source_base.h"
#include "../source/codegen_source_base.h"
#include "../../runtime/opencl/opencl_module.h"
namespace tvm {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -20,7 +20,7 @@
/*!
* Optional module when build opencl is switched to off
*/
#include "../codegen_source_base.h"
#include "../source/codegen_source_base.h"
#include "../../runtime/opengl/opengl_module.h"
namespace tvm {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -20,7 +20,7 @@
/*!
* Optional module when build rocm is switched to off
*/
#include "../codegen_source_base.h"
#include "../source/codegen_source_base.h"
#include "../../runtime/rocm/rocm_module.h"
namespace tvm {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -20,7 +20,7 @@
/*!
* Optional module when build opencl is switched to off
*/
#include "../codegen_source_base.h"
#include "../source/codegen_source_base.h"
#include "../../runtime/opencl/opencl_module.h"
namespace tvm {
......
......@@ -20,13 +20,13 @@
/*!
* \file codegen_aocl.cc
*/
#include <tvm/build_module.h>
#include <tvm/target/target.h>
#include <vector>
#include <string>
#include "codegen_opencl.h"
#include "build_common.h"
#include "../runtime/opencl/aocl/aocl_module.h"
#include "../runtime/file_util.h"
#include "../build_common.h"
#include "../../runtime/opencl/aocl/aocl_module.h"
#include "../../runtime/file_util.h"
namespace tvm {
namespace codegen {
......
......@@ -23,8 +23,8 @@
#include <iomanip>
#include <cctype>
#include "codegen_c.h"
#include "../arith/compute_expr.h"
#include "../tir/pass/ir_util.h"
#include "../../arith/compute_expr.h"
#include "../../tir/pass/ir_util.h"
namespace tvm {
namespace codegen {
......
......@@ -21,13 +21,13 @@
* \file codegen_c.h
* \brief Common utilities to generated C style code.
*/
#ifndef TVM_CODEGEN_CODEGEN_C_H_
#define TVM_CODEGEN_CODEGEN_C_H_
#ifndef TVM_TARGET_SOURCE_CODEGEN_C_H_
#define TVM_TARGET_SOURCE_CODEGEN_C_H_
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/lowered_func.h>
#include <string>
#include <vector>
......@@ -214,4 +214,4 @@ class CodeGenC :
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_C_H_
#endif // TVM_TARGET_SOURCE_CODEGEN_C_H_
......@@ -20,11 +20,11 @@
/*!
* \file codegen_c_host.cc
*/
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <vector>
#include <string>
#include "codegen_c_host.h"
#include "build_common.h"
#include "../build_common.h"
namespace tvm {
namespace codegen {
......
......@@ -21,10 +21,10 @@
* \file codegen_c_host.h
* \brief Generate C host code.
*/
#ifndef TVM_CODEGEN_CODEGEN_C_HOST_H_
#define TVM_CODEGEN_CODEGEN_C_HOST_H_
#ifndef TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_
#define TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
#include <string>
#include "codegen_c.h"
......@@ -75,4 +75,4 @@ class CodeGenCHost final : public CodeGenC {
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_C_HOST_H_
#endif // TVM_TARGET_SOURCE_CODEGEN_C_HOST_H_
......@@ -21,10 +21,10 @@
* \file codegen_cuda.h
* \brief Utility to generate cuda code
*/
#ifndef TVM_CODEGEN_CODEGEN_CUDA_H_
#define TVM_CODEGEN_CODEGEN_CUDA_H_
#ifndef TVM_TARGET_SOURCE_CODEGEN_CUDA_H_
#define TVM_TARGET_SOURCE_CODEGEN_CUDA_H_
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <tvm/tir/expr.h>
#include <string>
#include <unordered_map>
......@@ -93,4 +93,4 @@ class CodeGenCUDA final : public CodeGenC {
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_CUDA_H_
#endif // TVM_TARGET_SOURCE_CODEGEN_CUDA_H_
......@@ -24,9 +24,9 @@
#include <string>
#include <algorithm>
#include "codegen_metal.h"
#include "build_common.h"
#include "../runtime/metal/metal_module.h"
#include "../runtime/thread_storage_scope.h"
#include "../build_common.h"
#include "../../runtime/metal/metal_module.h"
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
namespace codegen {
......
......@@ -21,10 +21,10 @@
* \file codegen_metal.h
* \brief Generate Metal device code.
*/
#ifndef TVM_CODEGEN_CODEGEN_METAL_H_
#define TVM_CODEGEN_CODEGEN_METAL_H_
#ifndef TVM_TARGET_SOURCE_CODEGEN_METAL_H_
#define TVM_TARGET_SOURCE_CODEGEN_METAL_H_
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <string>
#include "codegen_c.h"
......@@ -60,4 +60,4 @@ class CodeGenMetal final : public CodeGenC {
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_METAL_H_
#endif // TVM_TARGET_SOURCE_CODEGEN_METAL_H_
......@@ -24,9 +24,9 @@
#include <vector>
#include <string>
#include "codegen_opencl.h"
#include "build_common.h"
#include "../runtime/thread_storage_scope.h"
#include "../runtime/opencl/opencl_module.h"
#include "../build_common.h"
#include "../../runtime/thread_storage_scope.h"
#include "../../runtime/opencl/opencl_module.h"
namespace tvm {
namespace codegen {
......
......@@ -21,10 +21,10 @@
* \file codegen_opencl.h
* \brief Generate OpenCL device code.
*/
#ifndef TVM_CODEGEN_CODEGEN_OPENCL_H_
#define TVM_CODEGEN_CODEGEN_OPENCL_H_
#ifndef TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_
#define TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <string>
#include "codegen_c.h"
......@@ -68,4 +68,4 @@ class CodeGenOpenCL final : public CodeGenC {
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_OPENCL_H_
#endif // TVM_TARGET_SOURCE_CODEGEN_OPENCL_H_
......@@ -28,8 +28,8 @@
#include <utility>
#include <unordered_map>
#include "codegen_opengl.h"
#include "build_common.h"
#include "../runtime/thread_storage_scope.h"
#include "../build_common.h"
#include "../../runtime/thread_storage_scope.h"
namespace tvm {
namespace codegen {
......
......@@ -21,15 +21,15 @@
* \file codegen_opengl.h
* \brief Generate OpenGL device code.
*/
#ifndef TVM_CODEGEN_CODEGEN_OPENGL_H_
#define TVM_CODEGEN_CODEGEN_OPENGL_H_
#ifndef TVM_TARGET_SOURCE_CODEGEN_OPENGL_H_
#define TVM_TARGET_SOURCE_CODEGEN_OPENGL_H_
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <string>
#include <unordered_set>
#include <unordered_map>
#include "codegen_c.h"
#include "../runtime/opengl/opengl_module.h"
#include "../../runtime/opengl/opengl_module.h"
namespace tvm {
namespace codegen {
......@@ -66,4 +66,4 @@ class CodeGenOpenGL final : public CodeGenC {
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_OPENGL_H_
#endif // TVM_TARGET_SOURCE_CODEGEN_OPENGL_H_
......@@ -21,17 +21,17 @@
* \file codegen_source_base.h
* \brief Common utilities to source code in text form.
*/
#ifndef TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
#define TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
#ifndef TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
#define TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <string>
#include <vector>
#include <functional>
#include <unordered_map>
#include "../runtime/meta_data.h"
#include "../../runtime/meta_data.h"
namespace tvm {
namespace codegen {
......@@ -154,4 +154,4 @@ runtime::Module DeviceSourceModuleCreate(
std::function<std::string(const std::string&)> fget_source = nullptr);
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
#endif // TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
......@@ -20,12 +20,11 @@
/*!
* \file codegen_vhls.cc
*/
#include <tvm/build_module.h>
#include <vector>
#include <string>
#include "codegen_vhls.h"
#include "build_common.h"
#include "../runtime/opencl/sdaccel/sdaccel_module.h"
#include "../build_common.h"
#include "../../runtime/opencl/sdaccel/sdaccel_module.h"
namespace tvm {
namespace codegen {
......
......@@ -21,10 +21,11 @@
* \file codegen_vhls.h
* \brief Utility to generate vhls code
*/
#ifndef TVM_CODEGEN_CODEGEN_VHLS_H_
#define TVM_CODEGEN_CODEGEN_VHLS_H_
#ifndef TVM_TARGET_SOURCE_CODEGEN_VHLS_H_
#define TVM_TARGET_SOURCE_CODEGEN_VHLS_H_
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <tvm/target/target.h>
#include <tvm/tir/expr.h>
#include <string>
#include "codegen_c.h"
......@@ -45,4 +46,4 @@ class CodeGenVivadoHLS final : public CodeGenC {
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_VHLS_H_
#endif // TVM_TARGET_SOURCE_CODEGEN_VHLS_H_
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -21,7 +21,7 @@
* \file intrin_rule_aocl.cc
* \brief AOCL intrinsic rules.
*/
#include "intrin_rule.h"
#include "../intrin_rule.h"
namespace tvm {
namespace codegen {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -21,7 +21,7 @@
* \file intrin_rule_cuda.cc
* \brief CUDA intrinsic rules.
*/
#include "intrin_rule.h"
#include "../intrin_rule.h"
namespace tvm {
namespace codegen {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -21,7 +21,7 @@
* \file intrin_rule_metal.cc
* \brief Metal intrinsic rules.
*/
#include "intrin_rule.h"
#include "../intrin_rule.h"
namespace tvm {
namespace codegen {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -21,7 +21,7 @@
* \file intrin_rule_opencl.cc
* \brief OpenCL intrinsic rules.
*/
#include "intrin_rule.h"
#include "../intrin_rule.h"
namespace tvm {
namespace codegen {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -21,7 +21,7 @@
* \file intrin_rule_opencl.cc
* \brief OpenCL intrinsic rules.
*/
#include "intrin_rule.h"
#include "../intrin_rule.h"
namespace tvm {
namespace codegen {
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -21,7 +21,7 @@
* \file intrin_rule_vhls.cc
* \brief VHLS intrinsic rules.
*/
#include "intrin_rule.h"
#include "../intrin_rule.h"
namespace tvm {
namespace codegen {
......
......@@ -21,8 +21,8 @@
* \file cuda_half_t.h
* \brief half_t (fp16) definition for cuda codegen.
*/
#ifndef TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_
#define TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_
#ifndef TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
#define TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
static constexpr const char* _cuda_half_t_def = R"(
typedef unsigned short uint16_t;
......@@ -295,4 +295,4 @@ __pack_half2(const half x, const half y) {
}
)";
#endif // TVM_CODEGEN_LITERAL_CUDA_HALF_T_H_
#endif // TVM_TARGET_SOURCE_LITERAL_CUDA_HALF_T_H_
......@@ -24,8 +24,8 @@
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include "codegen_source_base.h"
#include "../runtime/file_util.h"
#include "../runtime/meta_data.h"
#include "../../runtime/file_util.h"
#include "../../runtime/meta_data.h"
namespace tvm {
namespace codegen {
......
......@@ -21,8 +21,8 @@
* \file ir_builder.h
* \brief Utility for building SPIRV code block
*/
#ifndef TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
#define TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
#ifndef TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_
#define TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
......@@ -150,4 +150,4 @@ class CodeGenSPIRV:
} // namespace tvm
#endif // TVM_CODEGEN_SPIRV_CODEGEN_SPIRV_H_
#endif // TVM_TARGET_SPIRV_CODEGEN_SPIRV_H_
......@@ -21,8 +21,8 @@
* \file ir_builder.h
* \brief Utility for building SPIRV code block
*/
#ifndef TVM_CODEGEN_SPIRV_IR_BUILDER_H_
#define TVM_CODEGEN_SPIRV_IR_BUILDER_H_
#ifndef TVM_TARGET_SPIRV_IR_BUILDER_H_
#define TVM_TARGET_SPIRV_IR_BUILDER_H_
#include <tvm/runtime/packed_func.h>
#include <tvm/tir/expr.h>
......@@ -620,4 +620,4 @@ class IRBuilder {
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_SPIRV_IR_BUILDER_H_
#endif // TVM_TARGET_SPIRV_IR_BUILDER_H_
......@@ -21,13 +21,13 @@
* \file codegen_stack_vm.h
* \brief Codegen into Simple Stack VM.
*/
#ifndef TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_
#define TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_
#ifndef TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_
#define TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/lowered_func.h>
#include <tvm/codegen.h>
#include <tvm/target/codegen.h>
#include <string>
#include <vector>
#include <unordered_map>
......@@ -164,4 +164,4 @@ class CodeGenStackVM
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_STACKVM_CODEGEN_STACKVM_H_
#endif // TVM_TARGET_STACKVM_CODEGEN_STACKVM_H_
......@@ -269,6 +269,27 @@ tvm::Target Target::Current(bool allow_not_defined) {
return Target();
}
TVM_REGISTER_GLOBAL("_GetCurrentTarget")
.set_body([](TVMArgs args, TVMRetValue* ret) {
bool allow_not_defined = args[0];
*ret = Target::Current(allow_not_defined);
});
class Target::Internal {
public:
static void EnterScope(Target target) {
target.EnterWithScope();
}
static void ExitScope(Target target) {
target.ExitWithScope();
}
};
TVM_REGISTER_GLOBAL("_EnterTargetScope")
.set_body_typed(Target::Internal::EnterScope);
TVM_REGISTER_GLOBAL("_ExitTargetScope")
.set_body_typed(Target::Internal::ExitScope);
namespace target {
std::vector<std::string> MergeOptions(std::vector<std::string> opts,
const std::vector<std::string>& new_opts) {
......@@ -316,4 +337,125 @@ Target ext_dev(const std::vector<std::string>& options) {
return CreateTarget("ext_dev", options);
}
} // namespace target
BuildConfig BuildConfig::Create() {
return BuildConfig(make_object<BuildConfigNode>());
}
/*! \brief Entry to hold the BuildConfig context stack. */
struct TVMBuildConfigThreadLocalEntry {
/*! \brief The default build config if the stack is empty */
BuildConfig default_config;
/*! \brief The current build config context */
std::stack<BuildConfig> context_stack;
TVMBuildConfigThreadLocalEntry() :
default_config(BuildConfig::Create()) {
}
};
/*! \brief Thread local store to hold the BuildConfig context stack. */
typedef dmlc::ThreadLocalStore<TVMBuildConfigThreadLocalEntry> TVMBuildConfigThreadLocalStore;
void BuildConfig::EnterWithScope() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
entry->context_stack.push(*this);
}
void BuildConfig::ExitWithScope() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
CHECK(!entry->context_stack.empty());
CHECK(entry->context_stack.top().same_as(*this));
entry->context_stack.pop();
}
tvm::BuildConfig BuildConfig::Current() {
TVMBuildConfigThreadLocalEntry *entry = TVMBuildConfigThreadLocalStore::Get();
if (entry->context_stack.size() > 0) {
return entry->context_stack.top();
}
return entry->default_config;
}
TVM_REGISTER_NODE_TYPE(BuildConfigNode);
TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable)
.set_dispatch<BuildConfigNode>([](const ObjectRef& node, NodePrinter* p) {
auto* op = static_cast<const BuildConfigNode*>(node.get());
p->stream << "build_config(";
p->stream << "data_alignment=" << op->data_alignment << ", ";
p->stream << "offset_factor=" << op->offset_factor << ", ";
p->stream << "double_buffer_split_loop=" << op->double_buffer_split_loop << ", ";
p->stream << "auto_unroll_max_step=" << op->auto_unroll_max_step << ", ";
p->stream << "auto_unroll_max_depth=" << op->auto_unroll_max_depth << ", ";
p->stream << "auto_unroll_max_extent=" << op->auto_unroll_max_extent << ", ";
p->stream << "unroll_explicit=" << op->unroll_explicit << ", ";
p->stream << "restricted_func=" << op->restricted_func << ", ";
p->stream << "detect_global_barrier=" << op->detect_global_barrier << ", ";
p->stream << "partition_const_loop=" << op->partition_const_loop << ", ";
p->stream << "dump_pass_ir=" << op->dump_pass_ir << ", ";
p->stream << "instrument_bound_checkers=" << op->instrument_bound_checkers << ", ";
p->stream << "disable_select_rewriting=" << op->disable_select_rewriting;
p->stream << "disable_vectorize=" << op->disable_vectorize;
p->stream << "disable_assert=" << op->disable_assert;
p->stream << ")";
});
TVM_REGISTER_GLOBAL("_GetCurrentBuildConfig")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = BuildConfig::Current();
});
class BuildConfig::Internal {
public:
static void EnterScope(BuildConfig target) {
target.EnterWithScope();
}
static void ExitScope(BuildConfig target) {
target.ExitWithScope();
}
};
TVM_REGISTER_GLOBAL("_EnterBuildConfigScope")
.set_body_typed(BuildConfig::Internal::EnterScope);
TVM_REGISTER_GLOBAL("_ExitBuildConfigScope")
.set_body_typed(BuildConfig::Internal::ExitScope);
TVM_REGISTER_GLOBAL("_BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig cfg = args[0];
std::vector< std::pair<int, PackedFunc> > add_lower_pass;
CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) {
add_lower_pass.push_back(std::make_pair(
args[i].operator int(),
args[i + 1].operator tvm::runtime::PackedFunc()));
}
cfg->add_lower_pass = add_lower_pass;
});
TVM_REGISTER_GLOBAL("_BuildConfigGetAddLowerPassInfo")
.set_body([](TVMArgs args, TVMRetValue* ret) {
// Return one of the following:
// * Size of add_lower_pass if num_args == 1
// * Phase index of pass if args are (config, index, true)
// * Function of pass if args are (config, index, false)
BuildConfig cfg = args[0];
if (args.num_args == 1) {
*ret = static_cast<int64_t>(cfg->add_lower_pass.size());
} else {
int index = args[1];
bool get_phase = args[2];
auto item = cfg->add_lower_pass[index];
if (get_phase) {
*ret = item.first;
} else {
*ret = item.second;
}
}
});
} // namespace tvm
......@@ -23,7 +23,7 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include "../../codegen/datatype/registry.h"
#include "../../target/datatype/registry.h"
namespace tvm {
namespace tir {
......
......@@ -29,7 +29,7 @@
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/buffer.h>
#include <tvm/target/target_info.h>
#include <tvm/build_module.h>
#include <tvm/target/target.h>
#include <tvm/runtime/device_api.h>
#include <unordered_map>
#include "ir_util.h"
......
......@@ -22,7 +22,7 @@
#include <topi/cuda/injective.h>
#include <tvm/top/operation.h>
#include <tvm/runtime/registry.h>
#include <tvm/build_module.h>
#include <tvm/driver/driver.h>
#include <string>
#include <cmath>
......
......@@ -18,7 +18,7 @@
*/
#include <gtest/gtest.h>
#include <tvm/build_module.h>
#include <tvm/driver/driver.h>
#include <tvm/top/operation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
......
......@@ -19,7 +19,7 @@
#include <gtest/gtest.h>
#include <topi/generic/injective.h>
#include <tvm/build_module.h>
#include <tvm/driver/driver.h>
#include <tvm/relay/expr.h>
#include <tvm/ir/module.h>
#include <tvm/relay/analysis.h>
......
......@@ -31,7 +31,7 @@
#include <gtest/gtest.h>
#include <topi/generic/injective.h>
#include <tvm/build_module.h>
#include <tvm/driver/driver.h>
#include <tvm/top/operation.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
......
......@@ -25,7 +25,8 @@
#define TOPI_CUDA_DENSE_H_
#include "tvm/top/operation.h"
#include "tvm/build_module.h"
#include "tvm/top/schedule_pass.h"
#include "tvm/target/generic_func.h"
#include "topi/tags.h"
#include "topi/detail/array_utils.h"
#include "topi/nn/dense.h"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment