Commit efae4be0 by Tianqi Chen Committed by GitHub

[MODULE/REFACTOR] Introduce Module for AOT and runtime linking. (#51)

parent 8f240ee7
......@@ -11,7 +11,7 @@ include $(config)
# specify tensor path
.PHONY: clean all test doc
all: lib/libtvm.a lib/libtvm.so
all: lib/libtvm.so lib/libtvm_runtime.so lib/libtvm.a
LIB_HALIDE_IR = HalideIR/lib/libHalideIR.a
......@@ -19,6 +19,11 @@ SRC = $(wildcard src/*.cc src/*/*.cc src/*/*/*.cc)
ALL_OBJ = $(patsubst src/%.cc, build/%.o, $(SRC))
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
RUNTIME_SRC = $(wildcard src/runtime/*.cc src/runtime/*/*.cc)
RUNTIME_DEP = $(patsubst src/%.cc, build/%.o, $(RUNTIME_SRC))
ALL_DEP = $(ALL_OBJ) $(LIB_HALIDE_IR)
export LDFLAGS = -pthread -lm
export CFLAGS = -std=c++11 -Wall -O2 -fno-rtti\
-Iinclude -Idmlc-core/include -IHalideIR/src -fPIC -DDMLC_ENABLE_RTTI=0
......@@ -77,15 +82,18 @@ build/%.o: src/%.cc
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
$(CXX) -c $(CFLAGS) -c $< -o $@
lib/libtvm.a: $(ALL_DEP)
lib/libtvm.so: $(ALL_DEP)
@mkdir -p $(@D)
ar crv $@ $(filter %.o, $?)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lib/libtvm.so: $(ALL_DEP)
lib/libtvm_runtime.so: $(RUNTIME_DEP)
@mkdir -p $(@D)
$(CXX) $(CFLAGS) $(FRAMEWORKS) -shared -o $@ $(filter %.o %.a, $^) $(LDFLAGS)
lib/libtvm.a: $(ALL_DEP)
@mkdir -p $(@D)
ar crv $@ $(filter %.o, $?)
$(LIB_HALIDE_IR): LIBHALIDEIR
LIBHALIDEIR:
......
/*!
* Copyright (c) 2016 by Contributors
* Copyright (c) 2017 by Contributors
* \file api_registry.h
* \brief This file defines the TVM API registry.
*
* The API registry stores type-erased functions.
* Each registered function is automatically exposed
* to front-end language(e.g. python).
* Front-end can also pass callbacks as PackedFunc, or register
* then into the same global registry in C++.
* The goal is to mix the front-end language and the TVM back-end.
*
* \code
* // register the function as MyAPIFuncName
* TVM_REGISTER_API(MyAPIFuncName)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
* \brief This files include necessary headers to
* be used to register an global API function.
*/
#ifndef TVM_API_REGISTRY_H_
#define TVM_API_REGISTRY_H_
#include <dmlc/base.h>
#include <string>
#include "./base.h"
#include "./runtime/packed_func.h"
#include "./packed_func_ext.h"
namespace tvm {
/*! \brief Utility to register API. */
class APIRegistry {
public:
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
APIRegistry& set_body(PackedFunc f); // NOLINT(*)
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
APIRegistry& set_body(PackedFunc::FType f) { // NOLINT(*)
return set_body(PackedFunc(f));
}
/*!
* \brief Register a function with given name
* \param name The name of the function.
*/
static APIRegistry& __REGISTER__(const std::string& name); // NOLINT(*)
private:
/*! \brief name of the function */
std::string name_;
};
#include "./runtime/registry.h"
/*!
* \brief Get API function by name.
* \brief Register an API function globally.
* It simply redirects to TVM_REGISTER_GLOBAL
*
* \param name The name of the function.
* \return the corresponding API function.
* \note It is really PackedFunc::GetGlobal under the hood.
*/
inline PackedFunc GetAPIFunc(const std::string& name) {
return PackedFunc::GetGlobal(name);
}
#define _TVM_REGISTER_VAR_DEF_ \
static DMLC_ATTRIBUTE_UNUSED ::tvm::APIRegistry& __make_TVMRegistry_
/*!
* \brief Register API function globally.
* \code
* TVM_REGISTER_API(MyPrint)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
......@@ -78,8 +22,6 @@ inline PackedFunc GetAPIFunc(const std::string& name) {
* });
* \endcode
*/
#define TVM_REGISTER_API(OpName) \
DMLC_STR_CONCAT(_TVM_REGISTER_VAR_DEF_, __COUNTER__) = \
::tvm::APIRegistry::__REGISTER__(#OpName)
} // namespace tvm
#define TVM_REGISTER_API(OpName) TVM_REGISTER_GLOBAL(OpName)
#endif // TVM_API_REGISTRY_H_
......@@ -10,9 +10,9 @@
#include "./base.h"
#include "./expr.h"
#include "./lowered_func.h"
#include "./api_registry.h"
#include "./runtime/packed_func.h"
namespace tvm {
/*! \brief namespace for lowlevel IR pass and codegen */
namespace codegen {
......@@ -22,41 +22,21 @@ using runtime::TVMArgs;
using runtime::TVMRetValue;
/*!
* \brief Build a stack VM function.
* \param func The LoweredFunc to be build
* \param device_funcs The additional device functions
* \return A packed function representing the func.
*/
PackedFunc BuildStackVM(
LoweredFunc func,
const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs);
/*!
* \brief Build a LLVM VM function, this is still beta
* \param func The LoweredFunc to be build
* \return A packed function representing the func.
*/
PackedFunc BuildLLVM(LoweredFunc func);
/*!
* \brief Build a CUDA function with NVRTC
* \brief Build a module from array of lowered function.
* \param funcs The functions to be built.
* \param target The target to be built.
* \return The builded module.
*
* \param fsplits The LoweredFuncs to be build (after SplitHostDevice)
* The first element is the host function, followed by device functions.
* \param host_mode The host side compilation mode:
* - "stackvm": use stack vm to interpret host side code.
* \note Calls global API function "_codegen_build_" + target
*/
PackedFunc BuildNVRTC(Array<LoweredFunc> fsplits, std::string host_mode);
runtime::Module Build(const Array<LoweredFunc>& funcs,
const std::string& target);
/*!
* \brief Build a OpenCL function.
*
* \param fsplits The LoweredFuncs to be build (after SplitHostDevice)
* The first element is the host function, followed by device functions.
* \param host_mode The host side compilation mode:
* - "stackvm": use stack vm to interpret host side code.
* \param target The target to be queried.
* \return Whether target is enabled.
*/
PackedFunc BuildOpenCL(Array<LoweredFunc> fsplits, std::string host_mode);
bool TargetEnabled(const std::string& target);
} // namespace codegen
} // namespace tvm
......
......@@ -120,23 +120,14 @@ constexpr const char* tvm_handle_is_null = "tvm_handle_is_null";
/*!
* \brief See pesudo code
*
* int tvm_call_global(name, TVMValue* args) {
* PackedFunc f = PackedFunc::GetGlobal(name);
* f (args, type_code_of(args), len(args));
* int tvm_call_packed(name, TVMValue* args) {
* ModuleNode* env = GetCurrentEnv();
* const PackedFunc* f = env->GetFuncFromEnv(name);
* (*f)(args, type_code_of(args), len(args));
* return 0;
* }
*/
constexpr const char* tvm_call_global = "tvm_call_global";
/*!
* \brief See pesudo code
*
* int tvm_call_device(name, TVMValue* args) {
* PackedFunc df = CodeGenEnv->GetDevice(name);
* f (args, type_code_of(args), len(args));
* return 0;
* }
*/
constexpr const char* tvm_call_device = "tvm_call_device";
constexpr const char* tvm_call_packed = "tvm_call_packed";
/*!
* \brief See pesudo code
*
......
......@@ -147,12 +147,15 @@ Stmt LiftAllocate(Stmt stmt);
* \param body The body of the function.
* \param name The name of the function.
* \param api_args Arguments to the function, can be either Var, or Buffer
* \param num_packed_args Number of arguments that are processed in packed form.
* \param num_unpacked_args Number of arguments that
* are processed in plain form instead of packed form.
* \return a LoweredFunc with the specified signiture.
*
* \note
* The function signiture have two cases
*
* let num_packed_args = len(api_args) - num_unpacked_args;
*
* if num_packed_args is zero:
* f(api_arg_0, api_arg_1, .., api_arg_n) where n == len(api_args)
*
......@@ -167,7 +170,7 @@ Stmt LiftAllocate(Stmt stmt);
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args);
int num_unpacked_args);
/*!
* \brief Count number of undefined vars in f.
......
......@@ -72,6 +72,8 @@ class LoweredFuncNode : public FunctionBaseNode {
* constant Expr of given type is used.
*/
Map<Var, Expr> handle_data_type;
/*! \brief Whether this function is packed function */
bool is_packed_func{true};
/*! \brief The body statment of the function */
Stmt body;
/*! \return name of the operation */
......@@ -88,6 +90,7 @@ class LoweredFuncNode : public FunctionBaseNode {
v->Visit("args", &args);
v->Visit("thread_axis", &thread_axis);
v->Visit("handle_data_type", &handle_data_type);
v->Visit("is_packed_func", &is_packed_func);
v->Visit("body", &body);
}
......
......@@ -51,9 +51,10 @@ typedef enum {
kArrayHandle = 5U,
kTVMType = 6U,
kNodeHandle = 7U,
kFuncHandle = 8U,
kStr = 9U,
kBytes = 10U
kModuleHandle = 8U,
kFuncHandle = 9U,
kStr = 10U,
kBytes = 11U
} TVMTypeCode;
/*!
......@@ -140,17 +141,19 @@ typedef struct {
TVMContext ctx;
} TVMArray;
/*!
* \brief The stream that is specific to device
* can be NULL, which indicates the default one.
*/
typedef void* TVMStreamHandle;
/*! \brief Handle to TVM runtime modules. */
typedef void* TVMModuleHandle;
/*! \brief Handle to packed function handle. */
typedef void* TVMFunctionHandle;
/*! \brief Handle to hold return value. */
typedef void* TVMRetValueHandle;
/*! \brief the array handle */
typedef TVMArray* TVMArrayHandle;
/*!
* \brief The stream that is specific to device
* can be NULL, which indicates the default one.
*/
typedef void* TVMStreamHandle;
/*!
* \brief Used for implementing C API function.
......@@ -169,74 +172,87 @@ TVM_DLL void TVMAPISetLastError(const char* msg);
* \return error info
*/
TVM_DLL const char *TVMGetLastError(void);
/*!
* \brief Initialize certain type of devices, this may
* not be necessary for all device types. But is needed for OpenCL.
* \brief Load module from file.
* \param file_name The file name to load the module from.
* \param format The format of the module.
* \param out The result module
*
* \param dev_mask The device mask of device type to be initialized
* \param option_keys Additional option keys to pass.
* \param option_vals Additional option values to pass
* \param num_options Number of options to be passed into it.
* \param out_code 1: success, 0: already initialized
* \return 0 when success, -1 when failure happens
* \note The resulting module do not contain import relation.
* It can be reconstructed by TVMModImport.
*/
TVM_DLL int TVMDeviceInit(int dev_mask,
const char** option_keys,
const char** option_vals,
int num_options,
int *out_code);
TVM_DLL int TVMModLoadFromFile(const char* file_name,
const char* format,
TVMModuleHandle* out);
/*!
* \brief Whether the specified context is enabled.
* \brief Add dep to mod's dependency.
* This allows functions in this module to use modules.
*
* \param ctx The context to be checked.
* \param out_enabled whether the ctx is enabled.
* \return Whether the function is successful.
* \param mod The module handle.
* \param dep The dependent module to be imported.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMContextEnabled(TVMContext ctx,
int* out_enabled);
TVM_DLL int TVMModImport(TVMModuleHandle mod,
TVMModuleHandle dep);
/*!
* \brief Allocate a nd-array's memory,
* including space of shape, of given spec.
*
* \param shape The shape of the array, the data content will be copied to out
* \param ndim The number of dimension of the array.
* \param dtype The array data type.
* \param ctx The ctx this array sits on.
* \param out The output handle.
* \return 0 when success, -1 when failure happens
* \brief Get function from the module.
* \param mod The module handle.
* \param func_name The name of the function.
* \param query_imports Whether to query imported modules
* \param out The result function, can be NULL if it is not available.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim,
TVMType dtype,
TVMContext ctx,
TVMArrayHandle* out);
TVM_DLL int TVMModGetFunction(TVMModuleHandle mod,
const char* func_name,
int query_imports,
TVMFunctionHandle *out);
/*!
* \brief Free the TVM Array.
* \param handle The array handle to be freed.
* \return 0 when success, -1 when failure happens
* \brief Precompile the function under given context.
* Many TVMFunctionHandle is initialized lazily,
* This call eagerly prepares the resources under given context.
* Useful for benchmarking purposes.
*
* \param mod The module handle.
* \param func_name The name of the function.
* \param ctx The context to be precompiled on.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
TVM_DLL int TVMModPreCompile(TVMModuleHandle mod,
const char* func_name,
TVMContext ctx);
/*!
* \brief Copy the array, both from and to must be valid during the copy.
* \param from The array to be copied from.
* \param to The target space.
* \param stream The stream where the copy happens, can be NULL.
* \return 0 when success, -1 when failure happens
* \brief Backend function for modules to get function
* from its environment mod_node (its imports and global function).
*
* The user do should not call TVMFuncFree on func.
*
* \note This API is supposed to be used by backend,
* it is not supposed to be used by user.
*
* \param mod_node The module handle.
* \param func_name The name of the function.
* \param out The result function.
* \return 0 when no error is thrown, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMArrayHandle to,
TVMStreamHandle stream);
TVM_DLL int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *out);
/*!
* \brief Wait until all computations on stream completes.
* \param ctx The ctx to be synchronized.
* \param stream The stream to be synchronized.
* \return 0 when success, -1 when failure happens
* \brief Free the Module
* \param mod The module to be freed.
*
* \note This may not free up the module's resources.
* If there is active TVMFunctionHandle uses the module
* Or if this module is imported by another active module.
*
* The all functions remains valid until TVMFuncFree is called.
*/
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
TVM_DLL int TVMModFree(TVMModuleHandle mod);
/*!
* \brief Free the function when it is no longer needed.
......@@ -355,6 +371,76 @@ TVM_DLL int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out);
*/
TVM_DLL int TVMFuncListGlobalNames(int *out_size,
const char*** out_array);
// Array related apis for quick proptying
/*!
* \brief Initialize certain type of devices, this may
* not be necessary for all device types. But is needed for OpenCL.
*
* \param dev_mask The device mask of device type to be initialized
* \param option_keys Additional option keys to pass.
* \param option_vals Additional option values to pass
* \param num_options Number of options to be passed into it.
* \param out_code 1: success, 0: already initialized
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMDeviceInit(int dev_mask,
const char** option_keys,
const char** option_vals,
int num_options,
int *out_code);
/*!
* \brief Whether the specified context is enabled.
*
* \param ctx The context to be checked.
* \param out_enabled whether the ctx is enabled.
* \return Whether the function is successful.
*/
TVM_DLL int TVMContextEnabled(TVMContext ctx,
int* out_enabled);
/*!
* \brief Allocate a nd-array's memory,
* including space of shape, of given spec.
*
* \param shape The shape of the array, the data content will be copied to out
* \param ndim The number of dimension of the array.
* \param dtype The array data type.
* \param ctx The ctx this array sits on.
* \param out The output handle.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayAlloc(const tvm_index_t* shape,
tvm_index_t ndim,
TVMType dtype,
TVMContext ctx,
TVMArrayHandle* out);
/*!
* \brief Free the TVM Array.
* \param handle The array handle to be freed.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayFree(TVMArrayHandle handle);
/*!
* \brief Copy the array, both from and to must be valid during the copy.
* \param from The array to be copied from.
* \param to The target space.
* \param stream The stream where the copy happens, can be NULL.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMArrayCopyFromTo(TVMArrayHandle from,
TVMArrayHandle to,
TVMStreamHandle stream);
/*!
* \brief Wait until all computations on stream completes.
* \param ctx The ctx to be synchronized.
* \param stream The stream to be synchronized.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream);
} // TVM_EXTERN_C
#endif // TVM_RUNTIME_C_RUNTIME_API_H_
/*!
* Copyright (c) 2017 by Contributors
* \file module.h
* \brief Runtime container of the functions generated by TVM,
* This is used to support dynamically link, load and save
* functions from different convention under unified API.
*/
#ifndef TVM_RUNTIME_MODULE_H_
#define TVM_RUNTIME_MODULE_H_
#include <memory>
#include <vector>
#include <string>
#include <unordered_map>
#include "./c_runtime_api.h"
namespace tvm {
namespace runtime {
// The internal container of module.
class ModuleNode;
class PackedFunc;
/*!
* \brief Module container of TVM.
*/
class Module {
public:
Module() {}
// constructor from container.
explicit Module(std::shared_ptr<ModuleNode> n)
: node_(n) {}
/*!
* \brief Get packed function from current module by name.
*
* \param name The name of the function.
* \param query_imports Whether also query dependency modules.
* \return The result function.
* This function will return PackedFunc(nullptr) if function do not exist.
*/
PackedFunc GetFunction(const std::string& name, bool query_imports);
/*!
* \brief Import another module into this module.
* \param other The module to be imported.
*
* \note Cyclic dependency is not allowed among modules,
* An error will be thrown when cyclic dependency is detected.
*/
void Import(Module other);
/*!
* \brief Load a module from file.
* \param file_name The name of the host function module.
* \param format The format of the file.
* \note This function won't load the import relationship.
* Re-create import relationship by calling Import.
*/
static Module LoadFromFile(const std::string& file_name,
const std::string& format);
/*! \return internal container */
inline ModuleNode* operator->();
private:
std::shared_ptr<ModuleNode> node_;
};
/*!
* \brief Base node container of module.
* Do not create this directly, instead use Module.
*/
class ModuleNode {
public:
/*! \brief virtual destructor */
virtual ~ModuleNode() {}
/*! \return The module type key */
virtual const char* type_key() const = 0;
/*!
* \brief Eagerly compile the function under certain context,
* assuming that it is used by the current thread.
*
* This is useful for benchmarking to eliminate lazy compilation
* overhead during the first execution of the kernel.
*
* \param name The name of the function.
* \param ctx The context to be executed.
*/
virtual void PreCompile(const std::string& name, TVMContext ctx) = 0;
/*!
* \brief Get a PackedFunc from module.
*
* The PackedFunc may not be fully initialized,
* there might still be first time running overhead when
* executing the function on certain devices.
* For benchmarking, use prepare to eliminate
*
* \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node.
*
* \return PackedFunc(nullptr) when it is not available.
*
* \note The function will always remain valid.
* If the function need resource from the module(e.g. late linking),
* it should capture sptr_to_self.
*/
virtual PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) = 0;
/*!
* \brief Save the module to file.
* \param file_name The file to be saved to.
* \param format The format of the file.
*/
virtual void SaveToFile(const std::string& file_name,
const std::string& format) = 0;
/*!
* \brief Get the source code of module, when available.
* \param format Format of the source code, can be empty by default.
* \return Possible source code when available.
*/
virtual std::string GetSource(
const std::string& format = "") = 0;
/*!
* \brief Get a function from current environment
* The environment includes all the imports as well as Global functions.
*
* \param name name of the function.
* \return The corresponding function.
*/
const PackedFunc* GetFuncFromEnv(const std::string& name);
/*! \return The module it imports from */
const std::vector<Module>& imports() const {
return imports_;
}
private:
friend class Module;
/*! \brief The modules this module depend on */
std::vector<Module> imports_;
/*! \brief Cache used by GetImport */
std::unordered_map<std::string,
std::unique_ptr<PackedFunc> > import_cache_;
};
/*! \brief namespace for constant symbols */
namespace symbol {
/*! \brief Global variable to store module context. */
constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
/*! \brief Local function to set the device during API entry. */
constexpr const char* tvm_entry_setdevice = "__tvm_entry_setdevice";
/*! \brief Placeholder for the module's entry function. */
constexpr const char* tvm_module_main = "__tvm_main__";
} // packed symbol
// implementations of inline functions.
inline ModuleNode* Module::operator->() {
return node_.get();
}
} // namespace runtime
} // namespace tvm
#include "./packed_func.h"
#endif // TVM_RUNTIME_MODULE_H_
/*!
* Copyright (c) 2017 by Contributors
* \file packed_func.h
* \brief Runtime related c++ class.
* \brief Type-erased function used across TVM API.
*/
#ifndef TVM_RUNTIME_PACKED_FUNC_H_
#define TVM_RUNTIME_PACKED_FUNC_H_
......@@ -15,6 +15,7 @@
#include <memory>
#include <type_traits>
#include "./c_runtime_api.h"
#include "./module.h"
namespace Halide {
// Forward declare type for extensions
......@@ -97,30 +98,14 @@ class PackedFunc {
inline void CallPacked(TVMArgs args, TVMRetValue* rv) const;
/*! \return the internal body function */
inline FType body() const;
/*!
* \brief Register f as into global function table
* \param name The name of the function.
* \param f The function to be registered.
* \return Reference to the registered function.
* \note The returned reference is valid until the end of the program
*/
static const PackedFunc& RegisterGlobal(const std::string& name, PackedFunc f);
/*!
* \brief Get the global function by name.
* \param name The name of the function.
* \return reference to the registered function.
*/
static const PackedFunc& GetGlobal(const std::string& name);
/*!
* \brief Whether the global function exist
* \param name The name of the function.
* \return Whetehr the global function exist.
*/
static bool GlobalExist(const std::string& name);
/*!
* \brief Get the names of currently registered global function.
*/
static std::vector<std::string> ListGlobalNames();
/*! \return Whether the packed function is nullptr */
bool operator==(std::nullptr_t null) const {
return body_ == nullptr;
}
/*! \return Whether the packed function is not nullptr */
bool operator!=(std::nullptr_t null) const {
return body_ != nullptr;
}
private:
/*! \brief internal container of packed function */
......@@ -292,6 +277,10 @@ class TVMArgValue : public TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>();
}
const TVMValue& value() const {
return value_;
}
......@@ -364,6 +353,10 @@ class TVMRetValue : public TVMPODValue_ {
TVM_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
operator Module() const {
TVM_CHECK_TYPE_CODE(type_code_, kModuleHandle);
return *ptr<Module>();
}
// Assign operators
TVMRetValue& operator=(TVMRetValue&& other) {
this->Clear();
......@@ -415,6 +408,10 @@ class TVMRetValue : public TVMPODValue_ {
this->SwitchToClass(kFuncHandle, f);
return *this;
}
TVMRetValue& operator=(Module m) {
this->SwitchToClass(kModuleHandle, m);
return *this;
}
TVMRetValue& operator=(const TVMRetValue& other) { // NOLINT(*0
this->Assign(other);
return *this;
......@@ -444,6 +441,7 @@ class TVMRetValue : public TVMPODValue_ {
const TVMValue& value() const {
CHECK(type_code_ != kNodeHandle &&
type_code_ != kFuncHandle &&
type_code_ != kModuleHandle &&
type_code_ != kStr) << "TVMRetValue.value can only be used for POD data";
return value_;
}
......@@ -471,6 +469,10 @@ class TVMRetValue : public TVMPODValue_ {
SwitchToClass<PackedFunc>(kFuncHandle, other);
break;
}
case kModuleHandle: {
SwitchToClass<PackedFunc>(kModuleHandle, other);
break;
}
case kNodeHandle: {
SwitchToClass<std::shared_ptr<Node> >(
kNodeHandle, *other.template ptr<std::shared_ptr<Node> >());
......@@ -506,6 +508,7 @@ class TVMRetValue : public TVMPODValue_ {
switch (type_code_) {
case kStr: delete ptr<std::string>(); break;
case kFuncHandle: delete ptr<PackedFunc>(); break;
case kModuleHandle: delete ptr<Module>(); break;
case kNodeHandle: delete ptr<std::shared_ptr<Node> >(); break;
}
type_code_ = kNull;
......@@ -518,12 +521,14 @@ inline const char* TypeCode2Str(int type_code) {
case kInt: return "int";
case kFloat: return "float";
case kStr: return "str";
case kBytes: return "bytes";
case kHandle: return "handle";
case kNull: return "NULL";
case kNodeHandle: return "NodeHandle";
case kArrayHandle: return "ArrayHandle";
case kTVMType: return "TVMType";
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
}
......@@ -667,6 +672,10 @@ class TVMArgsSetter {
values_[i].v_handle = &value;
type_codes_[i] = kFuncHandle;
}
void operator()(size_t i, Module& value) const { // NOLINT(*)
values_[i].v_handle = &value;
type_codes_[i] = kModuleHandle;
}
void operator()(size_t i, TVMRetValue& value) const { // NOLINT(*)
if (value.type_code() == kStr) {
values_[i].v_str = value.ptr<std::string>()->c_str();
......
/*!
* Copyright (c) 2017 by Contributors
* \file registry.h
* \brief This file defines the TVM global function registry.
*
* The registered functions will be made available to front-end
* as well as backend users.
*
* The registry stores type-erased functions.
* Each registered function is automatically exposed
* to front-end language(e.g. python).
*
* Front-end can also pass callbacks as PackedFunc, or register
* then into the same global registry in C++.
* The goal is to mix the front-end language and the TVM back-end.
*
* \code
* // register the function as MyAPIFuncName
* TVM_REGISTER_GLOBAL(MyAPIFuncName)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#ifndef TVM_RUNTIME_REGISTRY_H_
#define TVM_RUNTIME_REGISTRY_H_
#include <string>
#include <vector>
#include "./packed_func.h"
namespace tvm {
namespace runtime {
/*! \brief Registry for global function */
class Registry {
public:
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
Registry& set_body(PackedFunc f); // NOLINT(*)
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
Registry& set_body(PackedFunc::FType f) { // NOLINT(*)
return set_body(PackedFunc(f));
}
/*!
* \brief Register a function with given name
* \param name The name of the function.
*/
static Registry& Register(const std::string& name); // NOLINT(*)
/*!
* \brief Erase global function from registry, if exist.
* \param name The name of the function.
* \return Whether function exist.
*/
static bool Remove(const std::string& name);
/*!
* \brief Get the global function by name.
* \param name The name of the function.
* \return pointer to the registered function,
* nullptr if it does not exist.
*/
static const PackedFunc* Get(const std::string& name); // NOLINT(*)
/*!
* \brief Get the names of currently registered global function.
* \return The names
*/
static std::vector<std::string> ListNames();
private:
/*! \brief name of the function */
std::string name_;
/*! \brief internal packed function */
PackedFunc func_;
// Internal class.
struct Manager;
friend struct Manager;
};
/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define TVM_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define TVM_ATTRIBUTE_UNUSED
#endif
/*!
* \brief Register a function globally.
* \code
* TVM_REGISTER_GLOBAL(MyPrint)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#define TVM_REGISTER_GLOBAL(OpName) \
static TVM_ATTRIBUTE_UNUSED ::tvm::runtime::Registry& \
__make_TVMRegistry_ ## OpName = \
::tvm::runtime::Registry::Register(#OpName)
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_REGISTRY_H_
......@@ -12,6 +12,7 @@ from . import ir_pass
from . import codegen
from . import collections
from . import schedule
from . import module
from . import ndarray as nd
from .ndarray import cpu, gpu, opencl, init_opencl, cl
......
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-branches
"""Symbolic configuration API."""
# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement
"""Function configuration API."""
from __future__ import absolute_import
import ctypes
......@@ -16,6 +16,7 @@ from ._node import NodeBase, SliceBase, convert_to_node
from ._ndarray import NDArrayBase
FunctionHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p
def _ctypes_free_resource(rhandle):
......@@ -110,6 +111,9 @@ def _make_tvm_args(args, temp_args):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.NODE_HANDLE
temp_args.append(arg)
elif isinstance(arg, ModuleBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, Function):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.FUNC_HANDLE
......@@ -158,6 +162,102 @@ class Function(object):
return RETURN_SWITCH[ret_tcode.value](ret_val)
class ModuleBase(object):
"""Base class for module"""
__slots__ = ["handle", "_entry"]
def __init__(self, handle):
self.handle = handle
self._entry = None
@property
def entry_func(self):
"""Get the entry function
Returns
-------
f : Function
The entry function if exist
"""
if self._entry:
return self._entry
else:
self._entry = self.get_function("__tvm_main__")
return self._entry
def get_function(self, name, query_imports=False):
"""Get function from the module.
Parameters
----------
name : str
The name of the function
query_imports : bool
Whether also query modules imported by this module.
Returns
-------
f : Function
The result function.
"""
ret_handle = FunctionHandle()
check_call(_LIB.TVMModGetFunction(
self.handle, c_str(name),
ctypes.c_int(query_imports),
ctypes.byref(ret_handle)))
if not ret_handle.value:
raise AttributeError(
"Module has no function '%s'" % name)
return Function(ret_handle)
def import_module(self, module):
"""Add module to the import list of current one.
Parameters
----------
module : Module
The other module.
"""
check_call(_LIB.TVMModImport(self.handle, module.handle))
def precompile(self, func_name, ctx):
"""Add module to the import list of current one.
Parameters
----------
func_name : str
The name of function to be precompiled.
ctx : Context
The context to be precompiled.
"""
check_call(_LIB.TVMModPreCompile(
self.handle, c_str(func_name), ctx))
def __getitem__(self, name):
if not isinstance(name, string_types):
raise ValueError("Can only take string as function name")
return self.get_function(name)
def __del__(self):
check_call(_LIB.TVMModFree(self.handle))
def __call__(self, *args):
if self._entry:
return self._entry(*args)
else:
f = self.entry_func
return f(*args)
_module_cls = None
def _return_module(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, ModuleHandle):
handle = ModuleHandle(handle)
return _module_cls(handle)
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
......@@ -167,6 +267,8 @@ def _handle_return_func(x):
# setup return handle for function type
RETURN_SWITCH[TypeCode.FUNC_HANDLE] = _handle_return_func
RETURN_SWITCH[TypeCode.MODULE_HANDLE] = _return_module
def register_func(func_name, f=None):
"""Register global function
......@@ -248,6 +350,7 @@ def _init_api_functions(root_namespace):
"_arith_": sys.modules["%s.arith" % root_namespace],
"_pass_": sys.modules["%s.ir_pass" % root_namespace],
"_codegen_": sys.modules["%s.codegen" % root_namespace],
"_module_": sys.modules["%s.module" % root_namespace],
"_schedule_": sys.modules["%s.schedule" % root_namespace]
}
for name in list_global_func_names():
......@@ -259,3 +362,9 @@ def _init_api_functions(root_namespace):
target_module = v
f = get_global_func(name)
setattr(target_module, fname, f)
def _init_module_module(module_class):
"""Initialize the module."""
global _module_cls
_module_cls = module_class
......@@ -18,9 +18,10 @@ class TypeCode(object):
ARRAY_HANDLE = 5
TVM_TYPE = 6
NODE_HANDLE = 7
FUNC_HANDLE = 8
STR = 9
BYTES = 10
MODULE_HANDLE = 8
FUNC_HANDLE = 9
STR = 10
BYTES = 11
def _api_type(code):
"""create a type accepted by API"""
......
# pylint: disable=invalid-name
"""Util to compile with C++ code"""
from __future__ import absolute_import as _abs
import sys
import subprocess
def create_shared(path_target, objects,
options=None, cc="g++"):
"""Create shared library.
Parameters
----------
path_target : str
The target shared library.
objects : list
List of object files.
options : str
The additional options.
cc : str
The compile string.
"""
cmd = [cc]
cmd += ["-shared"]
cmd += ["-o", path_target]
cmd += objects
if options:
cmd += options
args = ' '.join(cmd)
proc = subprocess.Popen(
args, shell=True,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
(out, _) = proc.communicate()
if proc.returncode != 0:
sys.stderr.write("Compilation error:\n")
sys.stderr.write(out)
sys.stderr.flush()
# pylint: disable=invalid-name, too-many-locals
"""Util to compile with NVCC"""
from __future__ import absolute_import as _abs
import os
import sys
import tempfile
import subprocess
def compile_source(code, target="cubin", options=None):
def compile_source(code, target="ptx", arch=None,
options=None, path_target=None):
"""Compile cuda code with NVCC from env.
Parameters
......@@ -15,9 +18,15 @@ def compile_source(code, target="cubin", options=None):
target : str
The target format
arch : str
The architecture
options : str
The additional options
path_target : str, optional
Output file.
Return
------
cubin : bytearray
......@@ -26,18 +35,23 @@ def compile_source(code, target="cubin", options=None):
temp_dir = tempfile.mkdtemp()
if target not in ["cubin", "ptx", "fatbin"]:
raise ValueError("target must be in cubin, ptx, fatbin")
path_code = os.path.join(temp_dir, "my_kernel.cu")
path_target = os.path.join(temp_dir, "my_kernel.%s" % target)
temp_code = os.path.join(temp_dir, "my_kernel.cu")
temp_target = os.path.join(temp_dir, "my_kernel.%s" % target)
with open(path_code, "w") as out_file:
with open(temp_code, "w") as out_file:
out_file.write(code)
if target == "cubin" and arch is None:
raise ValueError("arch(sm_xy) must be passed for generating cubin")
file_target = path_target if path_target else temp_target
cmd = ["nvcc"]
cmd += ["--%s" % target, "-O3"]
cmd += ["-o", path_target]
cmd += ["-arch", arch]
cmd += ["-o", file_target]
if options:
cmd += options
cmd += [path_code]
cmd += [temp_code]
args = ' '.join(cmd)
proc = subprocess.Popen(
......@@ -52,9 +66,9 @@ def compile_source(code, target="cubin", options=None):
sys.stderr.flush()
cubin = None
else:
cubin = bytearray(open(path_target, "rb").read())
os.remove(path_code)
if os.path.exists(path_target):
os.remove(path_target)
cubin = bytearray(open(file_target, "rb").read())
os.remove(temp_code)
if os.path.exists(temp_target):
os.remove(temp_target)
os.rmdir(temp_dir)
return cubin
# pylint: disable=protected-access, no-member
"""Arithmetic data structure and utility"""
from __future__ import absolute_import as _abs
from ._ctypes._node import NodeBase, register_node
from . import _api_internal
......@@ -15,6 +16,7 @@ class IntSet(NodeBase):
"""Whether the set represent everything"""
return _api_internal._IntSetIsEverything(self)
@register_node
class IntervalSet(IntSet):
"""Represent set of continuous interval"""
......@@ -26,8 +28,8 @@ class IntervalSet(IntSet):
"""get the maximum value"""
return _api_internal._IntervalSetGetMax(self)
@register_node
class StrideSet(IntSet):
"""Represent set of strided integers"""
pass
......@@ -15,9 +15,9 @@ from . import codegen
def build(sch,
args,
target,
target_host="stackvm",
name="default_function",
binds=None,
record_codes=None,
max_auto_unroll_step=8):
"""Build a function with arguments as signiture.
......@@ -32,6 +32,9 @@ def build(sch,
target : str
The target of the compilation.
target_host :
Host compilation target, if target is device.
name : str
The name of result function.
......@@ -74,22 +77,17 @@ def build(sch,
stmt = ir_pass.LiftAllocate(stmt)
stmt = ir_pass.UnrollLoop(stmt, max_auto_unroll_step)
stmt = ir_pass.Simplify(stmt)
fapi = ir_pass.MakeAPI(stmt, name, arg_list, len(arg_list))
fapi = ir_pass.MakeAPI(stmt, name, arg_list, 0)
fsplits = ir_pass.SplitHostDevice(fapi)
fsplits = [x for x in fsplits]
for i in range(1, len(fsplits)):
fsplits[i] = ir_pass.StorageSync(fsplits[i], "shared")
if record_codes is not None:
output_ssa = False
for i, f in enumerate(fsplits):
t = target if i >= 1 else "c"
record_codes.append(codegen.CompileToC(f, output_ssa, t))
if target == "cuda":
ret = codegen.BuildNVRTC(fsplits, "stackvm")
elif target == "opencl":
ret = codegen.BuildOpenCL(fsplits, "stackvm")
if len(fsplits) > 1:
mhost = codegen.build(fsplits[0], target_host)
if target:
mdev = codegen.build(fsplits[1:], target)
mhost.import_module(mdev)
return mhost
else:
raise ValueError("Unknown target %s" % target)
return ret
return codegen.build(fsplits[0], target)
"""Runtime module related stuffs"""
# pylint: disable=unused-import, invalid-name, undefined-variable
from __future__ import absolute_import as _abs
from ._ctypes._function import ModuleBase, _init_module_module
class Module(ModuleBase):
"""Module container of all TVM generated functions"""
def __repr__(self):
return "Module(%s, %x)" % (self.type_key, self.handle.value)
@property
def type_key(self):
"""Get type key of the module."""
return _GetTypeKey(self)
def get_source(self, fmt=""):
"""Get source code from module, if available.
Parameters
----------
fmt : str, optional
The specified format.
"""
return _GetSource(self, fmt)
@property
def imported_modules(self):
"""Get imported modules
Returns
----------
modules : list of Modules
The module
"""
nmod = ImportsSize(self)
return [_GetImport(self, i) for i in range(nmod)]
def save(self, file_name, fmt=""):
"""Save the module to file.
Parameters
----------
file_name : str
The name of the file.
fmt : str
The format of the file.
"""
_SaveToFile(self, file_name, fmt)
def load(path, fmt=""):
"""Load module from file
Parameters
----------
path : str
The path to the module file.
fmt : str, optional
The format of the file, if not specified
it will be inferred from suffix of the file.
"""
return _LoadFromFile(path, fmt)
_init_module_module(Module)
......@@ -6,51 +6,24 @@
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
#include <tvm/api_registry.h>
#include "../codegen/codegen_c.h"
#include "../codegen/codegen_cuda.h"
#include "../codegen/codegen_opencl.h"
namespace tvm {
namespace codegen {
TVM_REGISTER_API(_codegen_CompileToC)
TVM_REGISTER_API(_codegen_build)
.set_body([](TVMArgs args, TVMRetValue *ret) {
std::string mode = "c";
if (args.size() > 2) {
mode = args[2].operator std::string();
}
if (mode == "c") {
*ret = CodeGenC().Compile(args[0], args[1]);
} else if (mode == "cuda") {
*ret = CodeGenCUDA().Compile(args[0], args[1]);
} else if (mode == "opencl") {
*ret = CodeGenOpenCL().Compile(args[0], args[1]);
if (args[0].IsNodeType<LoweredFunc>()) {
*ret = Build({args[0]}, args[1]);
} else {
LOG(FATAL) << "cannot recognize mode";
*ret = Build(args[0], args[1]);
}
});
TVM_REGISTER_API(_codegen_BuildStackVM)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = BuildStackVM(args[0],
std::unordered_map<LoweredFunc, PackedFunc>());
});
TVM_REGISTER_API(_codegen_BuildLLVM)
TVM_REGISTER_API(_codegen_target_enabled)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = BuildLLVM(args[0]);
*ret = TargetEnabled(args[0]);
});
TVM_REGISTER_API(_codegen_BuildNVRTC)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = BuildNVRTC(args[0], args[1]);
});
TVM_REGISTER_API(_codegen_BuildOpenCL)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = BuildOpenCL(args[0], args[1]);
});
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file api_registry.cc
*/
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/api_registry.h>
#include <memory>
namespace tvm {
struct APIManager {
std::unordered_map<std::string, std::unique_ptr<APIRegistry> > fmap;
static APIManager* Global() {
static APIManager inst;
return &inst;
}
};
APIRegistry& APIRegistry::__REGISTER__(const std::string& name) { // NOLINT(*)
APIManager* m = APIManager::Global();
CHECK(!m->fmap.count(name))
<< "API function " << name << " has already been registered";
std::unique_ptr<APIRegistry> p(new APIRegistry());
p->name_ = name;
m->fmap[name] = std::move(p);
return *(m->fmap[name]);
}
APIRegistry& APIRegistry::set_body(PackedFunc f) { // NOLINT(*)
PackedFunc::RegisterGlobal(name_, f);
return *this;
}
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file nvrtc.cc
* Build cuda modules from source.
* \file build_cuda.cc
*/
#include "./cuda_common.h"
#include <tvm/base.h>
#include <tvm/runtime/config.h>
#include "./codegen_cuda.h"
#if TVM_CUDA_RUNTIME
#include <nvrtc.h>
#include "../runtime/meta_data.h"
#include "../runtime/cuda/cuda_common.h"
#include "../runtime/cuda/cuda_module.h"
namespace tvm {
namespace runtime {
namespace codegen {
#define NVRTC_CALL(x) \
{ \
......@@ -41,6 +47,46 @@ std::string NVRTCCompile(const std::string& code) {
return ptx;
}
} // namespace runtime
runtime::Module BuildCUDA(Array<LoweredFunc> funcs) {
using tvm::runtime::Registry;
bool output_ssa = false;
CodeGenCUDA cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
}
std::string code = cg.Finish();
if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
code = (*f)(code).operator std::string();
}
std::string ptx;
if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
ptx = (*f)(code).operator std::string();
} else {
ptx = NVRTCCompile(code);
}
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (LoweredFunc f : funcs) {
runtime::FunctionInfo info;
for (size_t i = 0; i < f->args.size(); ++i) {
info.arg_types.push_back(Type2TVMType(f->args[i].type()));
}
for (size_t i = 0; i < f->thread_axis.size(); ++i) {
info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag);
}
fmap[f->name] = info;
}
return CUDAModuleCreate(ptx, "ptx", fmap, code);
}
TVM_REGISTER_API(_codegen_build_cuda)
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildCUDA(args[0]);
});
} // namespace codegen
} // namespace tvm
#endif // TVM_CUDA_RUNTIME
#endif // TVM_CUDA_RUNTIME
/*!
* Copyright (c) 2017 by Contributors
* Build opencl modules from source.
* \file build_opencl.cc
*/
#include <tvm/base.h>
#include <tvm/runtime/config.h>
#include "./codegen_opencl.h"
#if TVM_OPENCL_RUNTIME
#include "../runtime/meta_data.h"
#include "../runtime/opencl/opencl_common.h"
#include "../runtime/opencl/opencl_module.h"
namespace tvm {
namespace codegen {
runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
std::ostringstream os;
bool output_ssa = false;
CodeGenOpenCL cg;
cg.Init(output_ssa);
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
}
std::string code = cg.Finish();
std::unordered_map<std::string, runtime::FunctionInfo> fmap;
for (LoweredFunc f : funcs) {
runtime::FunctionInfo info;
for (size_t i = 0; i < f->args.size(); ++i) {
info.arg_types.push_back(Type2TVMType(f->args[i].type()));
}
for (size_t i = 0; i < f->thread_axis.size(); ++i) {
info.thread_axis_tags.push_back(f->thread_axis[i]->thread_tag);
}
fmap[f->name] = info;
}
return OpenCLModuleCreate(code, "cl", fmap);
}
TVM_REGISTER_API(_codegen_build_opencl)
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = BuildOpenCL(args[0]);
});
} // namespace codegen
} // namespace tvm
#endif // TVM_OPENCL_RUNTIME
/*!
* Copyright (c) 2017 by Contributors
* \file codegen.cc
* \brief Common utilities to generated C style code.
*/
#include <tvm/codegen.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
namespace tvm {
namespace codegen {
runtime::Module Build(const Array<LoweredFunc>& funcs,
const std::string& target) {
std::string mode = target;
size_t pos = mode.find("-");
if (pos != std::string::npos) {
mode = mode.substr(0, pos);
}
std::string build_f_name = "_codegen_build_" + mode;
const PackedFunc* bf = runtime::Registry::Get(build_f_name);
CHECK(bf != nullptr)
<< "Target " << target << " is not enabled";
runtime::Module m = (*bf)(funcs, target);
return m;
}
bool TargetEnabled(const std::string& target) {
std::string build_f_name = "_codegen_build_" + target;
return runtime::Registry::Get(build_f_name) != nullptr;
}
} // namespace codegen
} // namespace tvm
......@@ -12,9 +12,22 @@ namespace codegen {
using namespace ir;
std::string CodeGenC::Compile(LoweredFunc f,
bool output_ssa) {
void CodeGenC::Init(bool output_ssa) {
print_ssa_form_ = output_ssa;
}
void CodeGenC::InitFuncState(LoweredFunc f) {
alloc_storage_scope_.clear();
name_alloc_map_.clear();
ssa_assign_map_.clear();
var_idmap_.clear();
handle_data_type_.clear();
scope_mark_.clear();
}
void CodeGenC::AddFunction(LoweredFunc f) {
// clear previous generated state.
this->InitFuncState(f);
// skip the first underscore, so SSA variable starts from _1
GetUniqueName("_");
// add to alloc buffer type.
......@@ -47,7 +60,10 @@ std::string CodeGenC::Compile(LoweredFunc f,
this->PrintStmt(f->body);
this->EndScope(func_scope);
this->PrintIndent();
this->stream << "}\n";
this->stream << "}\n\n";
}
std::string CodeGenC::Finish() {
return stream.str();
}
......
......@@ -24,14 +24,20 @@ namespace codegen {
class CodeGenC {
public:
/*!
* \brief Generate the C code of statement
* \param f The function to be compiled
* \param output_ssa Whether output ssa form.
* \note Only call compile once,
* create a new codegen object each time.
* \brief Initialize the code generator.
* \param output_ssa Whether output SSA.
*/
std::string Compile(LoweredFunc f,
bool output_ssa);
void Init(bool output_ssa);
/*!
* \brief Add the function to the generated module.
* \param f The function to be compiled.
*/
void AddFunction(LoweredFunc f);
/*!
* \brief Finalize the compilation and return the code.
* \return The code.
*/
std::string Finish();
/*!
* \brief Print the Stmt n to CodeGenC->stream
* \param n The statement to be printed.
......@@ -74,6 +80,11 @@ class CodeGenC {
std::string GetVarID(const Variable* v) const;
// The following parts are overloadable print operations.
/*!
* \brief Initialize codegen state for generating f.
* \param f The function to be compiled.
*/
virtual void InitFuncState(LoweredFunc f);
/*!
* Print Type represetnation of type t.
* \param t The type representation.
* \param os The stream to print the ctype into
......@@ -182,7 +193,6 @@ class CodeGenC {
* \return The storage scope.
*/
std::string GetStorageScope(const Variable* buf_var) const;
/*! \brief the storage scope of allocation */
std::unordered_map<const Variable*, std::string> alloc_storage_scope_;
......
......@@ -4,23 +4,19 @@
*/
#include <tvm/base.h>
#include <tvm/runtime/config.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <vector>
#include <string>
#include "./codegen_cuda.h"
#include "./codegen_stack_vm.h"
#include "../arithmetic/compute_expr.h"
#include "../runtime/cuda/cuda_common.h"
#include "../runtime/cuda/cuda_module.h"
namespace tvm {
namespace codegen {
std::string CodeGenCUDA::Compile(
LoweredFunc f,
bool output_ssa) {
void CodeGenCUDA::AddFunction(LoweredFunc f) {
this->stream << "extern \"C\" __global__ ";
return CodeGenC::Compile(f, output_ssa);
CodeGenC::AddFunction(f);
}
void CodeGenCUDA::PrintStmt(const ir::For* op) {
......@@ -152,70 +148,5 @@ void CodeGenCUDA::PrintStorageScope(
os << "__shared__ ";
}
}
#if TVM_CUDA_RUNTIME
std::unordered_map<LoweredFunc, PackedFunc>
MakeNVRTC(Array<LoweredFunc> funcs) {
std::ostringstream os;
bool output_ssa = false;
for (LoweredFunc f : funcs) {
os << CodeGenCUDA().Compile(f, output_ssa);
os << '\n';
}
std::string code = os.str();
if (PackedFunc::GlobalExist("tvm_callback_cuda_postproc")) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_postproc");
code = f(code).operator std::string();
}
std::string ptx;
if (PackedFunc::GlobalExist("tvm_callback_cuda_compile")) {
const auto& f = PackedFunc::GetGlobal("tvm_callback_cuda_compile");
ptx = f(code).operator std::string();
} else {
ptx = runtime::NVRTCCompile(os.str());
}
std::unordered_map<LoweredFunc, PackedFunc> ret;
runtime::CUDAModule m = runtime::CUDAModule::Create(ptx);
for (LoweredFunc f : funcs) {
std::vector<TVMType> arg_types(f->args.size());
std::vector<std::string> thread_axis_tags(f->thread_axis.size());
for (size_t i = 0; i < f->args.size(); ++i) {
arg_types[i] = Type2TVMType(f->args[i].type());
}
for (size_t i = 0; i < f->thread_axis.size(); ++i) {
thread_axis_tags[i] = f->thread_axis[i]->thread_tag;
}
ret[f] = m.GetPackedFunc(f->name, arg_types, thread_axis_tags);
}
return ret;
}
PackedFunc BuildNVRTC(Array<LoweredFunc> fsplits, std::string host_mode) {
Array<LoweredFunc> device_list(fsplits.begin() + 1, fsplits.end());
std::unordered_map<LoweredFunc, PackedFunc> device_funcs = MakeNVRTC(device_list);
if (host_mode == "stackvm") {
StackVM vm = codegen::CodeGenStackVM().Compile(fsplits[0], device_funcs);
auto f = [vm](TVMArgs args, TVMRetValue* rv) {
runtime::AutoSetCUDADevice(args);
vm(args);
};
return PackedFunc(f);
} else {
LOG(FATAL) << "unknown host mode " << host_mode;
return PackedFunc();
}
}
#else
// dummy function when cuda is not available
PackedFunc BuildNVRTC(Array<LoweredFunc> func, std::string host_mode) {
LOG(FATAL) << "CUDA is not enabled";
return PackedFunc();
}
#endif // TVM_CUDA_RUNTIME
} // namespace codegen
} // namespace tvm
......@@ -16,16 +16,7 @@ namespace codegen {
class CodeGenCUDA : public CodeGenC {
public:
/*!
* \brief Generate the C code of statement
* \param f The function to be compiled
* \param output_ssa Whether output ssa form.
* \note Only call compile once,
* create a new codegen object each time.
*/
std::string Compile(LoweredFunc f,
bool output_ssa);
void AddFunction(LoweredFunc f);
// override behavior
void PrintStmt(const ir::For* op) final;
void PrintStorageSync(const std::string& sync) final;
......
......@@ -7,24 +7,23 @@
#include <vector>
#include <string>
#include "./codegen_opencl.h"
#include "./codegen_stack_vm.h"
#include "../runtime/opencl/opencl_common.h"
#include "../runtime/opencl/opencl_module.h"
#include "../runtime/thread_storage_scope.h"
namespace tvm {
namespace codegen {
std::string CodeGenOpenCL::Compile(
LoweredFunc f,
bool output_ssa) {
this->stream << " __kernel ";
void CodeGenOpenCL::InitFuncState(LoweredFunc f) {
CodeGenC::InitFuncState(f);
for (Var arg : f->args) {
if (arg.type().is_handle()) {
alloc_storage_scope_[arg.get()] = "global";
}
}
return CodeGenC::Compile(f, output_ssa);
}
void CodeGenOpenCL::AddFunction(LoweredFunc f) {
this->stream << " __kernel ";
CodeGenC::AddFunction(f);
}
void CodeGenOpenCL::PrintThreadIndexExpr(
......@@ -129,56 +128,5 @@ void CodeGenOpenCL::PrintStorageScope(const std::string& scope, std::ostream& os
os << "__local ";
}
}
#if TVM_OPENCL_RUNTIME
std::unordered_map<LoweredFunc, PackedFunc>
MakeOpenCL(Array<LoweredFunc> funcs) {
std::ostringstream os;
bool output_ssa = false;
for (LoweredFunc f : funcs) {
os << CodeGenOpenCL().Compile(f, output_ssa);
os << '\n';
}
std::unordered_map<LoweredFunc, PackedFunc> ret;
runtime::OpenCLModule m =
runtime::OpenCLModule::CreateWithSource(os.str());
for (LoweredFunc f : funcs) {
std::vector<TVMType> arg_types(f->args.size());
std::vector<std::string> thread_axis_tags(f->thread_axis.size());
for (size_t i = 0; i < f->args.size(); ++i) {
arg_types[i] = Type2TVMType(f->args[i].type());
}
for (size_t i = 0; i < f->thread_axis.size(); ++i) {
thread_axis_tags[i] = f->thread_axis[i]->thread_tag;
}
ret[f] = m.GetPackedFunc(f->name, arg_types, thread_axis_tags);
}
return ret;
}
PackedFunc BuildOpenCL(Array<LoweredFunc> fsplits, std::string host_mode) {
Array<LoweredFunc> device_list(fsplits.begin() + 1, fsplits.end());
std::unordered_map<LoweredFunc, PackedFunc> device_funcs = MakeOpenCL(device_list);
if (host_mode == "stackvm") {
StackVM vm = codegen::CodeGenStackVM().Compile(fsplits[0], device_funcs);
auto f = [vm](TVMArgs args, TVMRetValue* rv) {
runtime::AutoSetOpenCLContext(args);
vm(args);
};
return PackedFunc(f);
} else {
LOG(FATAL) << "unknown host mode " << host_mode;
return PackedFunc();
}
}
#else
// dummy function when opencl is not available
PackedFunc BuildOpenCL(Array<LoweredFunc> func, std::string host_mode) {
LOG(FATAL) << "OpenCL is not enabled";
return PackedFunc();
}
#endif // TVM_OPENCL_RUNTIME
} // namespace codegen
} // namespace tvm
......@@ -16,16 +16,9 @@ namespace codegen {
class CodeGenOpenCL : public CodeGenC {
public:
/*!
* \brief Generate the OpenCL code of statement
* \param f The function to be compiled
* \param output_ssa Whether output ssa form.
* \note Only call compile once,
* create a new codegen object each time.
*/
std::string Compile(LoweredFunc f,
bool output_ssa);
void AddFunction(LoweredFunc f);
// override print thread tag.
void InitFuncState(LoweredFunc f) final;
void PrintThreadIndexExpr(
std::string tag, std::ostream& os) final; // NOLINT(*)
void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
......
......@@ -12,6 +12,7 @@ namespace tvm {
namespace codegen {
void CodeGenLLVM::Init(const std::string& module_name,
const std::string& target_triple,
llvm::LLVMContext* ctx) {
InitializeLLVM();
static_assert(sizeof(TVMValue) == sizeof(double), "invariant");
......@@ -60,16 +61,37 @@ void CodeGenLLVM::Init(const std::string& module_name,
t_tvm_value_->getPointerTo(),
t_int_->getPointerTo()}, false),
llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get());
f_tvm_func_get_global_ = llvm::Function::Create(
f_tvm_get_func_from_env_ = llvm::Function::Create(
llvm::FunctionType::get(t_int_, {
t_void_p_,
t_char_->getPointerTo(),
t_tvm_func_handle_->getPointerTo()}, false),
llvm::Function::ExternalLinkage, "TVMFuncGetGlobal", module_.get());
llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get());
f_tvm_api_set_last_error_ = llvm::Function::Create(
llvm::FunctionType::get(t_void_, {t_char_->getPointerTo()}, false),
llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
this->InitTarget(target_triple);
// initialize builder
builder_.reset(new IRBuilder(*ctx));
this->InitGlobalContext();
}
void CodeGenLLVM::InitTarget(const std::string& target) {
llvm::TargetMachine* tm;
std::string target_triple;
std::tie(tm, target_triple) = LLVMGetTarget(target);
module_->setTargetTriple(target_triple);
module_->setDataLayout(tm->createDataLayout());
data_layout_.reset(new llvm::DataLayout(module_.get()));
}
void CodeGenLLVM::InitGlobalContext() {
gv_mod_ctx_ = new llvm::GlobalVariable(
*module_, t_void_p_, false,
llvm::GlobalValue::LinkOnceODRLinkage, 0, "__tvm_module_ctx");
gv_mod_ctx_->setAlignment(data_layout_->getTypeAllocSize(t_void_p_));
gv_mod_ctx_->setInitializer(llvm::Constant::getNullValue(t_void_p_));
}
void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
......@@ -104,6 +126,24 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) {
builder_->CreateRet(ConstInt32(0));
}
void CodeGenLLVM::AddMainFunction(const std::string& entry_func_name) {
llvm::Function* f = module_->getFunction(entry_func_name);
CHECK(f) << "Function " << entry_func_name << "does not in module";
CHECK(!module_->getFunction(runtime::symbol::tvm_module_main));
llvm::FunctionType* ftype = f->getFunctionType();
function_ = llvm::cast<llvm::Function>(
module_->getOrInsertFunction(runtime::symbol::tvm_module_main, ftype));
function_->setCallingConv(llvm::CallingConv::C);
std::vector<llvm::Value*> args;
for (auto it = function_->arg_begin();
it != function_->arg_end(); ++it) {
args.push_back(&(*it));
}
llvm::BasicBlock* block = llvm::BasicBlock::Create(*ctx_, "entry", function_);
builder_->SetInsertPoint(block);
builder_->CreateRet(builder_->CreateCall(f, args));
}
class FPassManager : public llvm::legacy::FunctionPassManager {
public:
explicit FPassManager(llvm::Module* m)
......@@ -364,14 +404,12 @@ void CodeGenLLVM::Visit_(const Load* op) {
CHECK(!t.is_vector());
if (t.is_scalar()) {
llvm::DataLayout layout(module_.get());
uint64_t valign = layout.getTypeAllocSize(LLVMType(t));
llvm::LoadInst* inst = builder_->CreateAlignedLoad(
CreateBufferPtr(
t,
GetVarValue(op->buffer_var.get()),
MakeValue(op->index)),
valign);
data_layout_->getTypeAllocSize(LLVMType(t)));
AddAliasInfo(inst, op->buffer_var.get(), op->index);
value_ = inst;
} else {
......@@ -384,15 +422,13 @@ void CodeGenLLVM::Visit_(const Store* op) {
Type t = op->value.type();
CHECK(!t.is_vector());
if (t.is_scalar()) {
llvm::DataLayout layout(module_.get());
uint64_t valign = layout.getTypeAllocSize(value->getType());
llvm::StoreInst* inst = builder_->CreateAlignedStore(
value,
CreateBufferPtr(
t,
GetVarValue(op->buffer_var.get()),
MakeValue(op->index)),
valign);
data_layout_->getTypeAllocSize(value->getType()));
AddAliasInfo(inst, op->buffer_var.get(), op->index);
} else {
LOG(FATAL) << "not yet supported";
......@@ -400,8 +436,7 @@ void CodeGenLLVM::Visit_(const Store* op) {
}
void CodeGenLLVM::Visit_(const Call* op) {
if (op->is_intrinsic(intrinsic::tvm_call_global) ||
op->is_intrinsic(intrinsic::tvm_call_device)) {
if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
value_ = CreateCallPacked(op);
} else if (op->call_type == Call::Intrinsic ||
op->call_type == Call::PureIntrinsic) {
......@@ -734,14 +769,14 @@ llvm::Value* CodeGenLLVM::CreateCast(Type from, Type to, llvm::Value* value) {
}
}
llvm::Value* CodeGenLLVM::GetPackedFuncHandle(
const std::string& fname, bool global) {
llvm::Value* CodeGenLLVM::GetPackedFuncHandle(const std::string& fname) {
using llvm::BasicBlock;
// We will store the packed function handle in global space.
// Initialize it during the first call.
llvm::DataLayout layout(module_.get());
uint64_t halign = layout.getTypeAllocSize(t_tvm_func_handle_);
uint64_t align = layout.getTypeAllocSize(t_tvm_func_handle_);
auto it = func_handle_map_.find(fname);
llvm::GlobalVariable* hptr;
if (it == func_handle_map_.end()) {
// create global location for the handle
......@@ -749,7 +784,7 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(
hptr = new llvm::GlobalVariable(
*module_, t_tvm_func_handle_, false,
llvm::GlobalValue::PrivateLinkage, 0, ".tvm_func");
hptr->setAlignment(halign);
hptr->setAlignment(align);
hptr->setInitializer(llvm::Constant::getNullValue(t_tvm_func_handle_));
func_handle_map_[fname] = hptr;
} else {
......@@ -761,26 +796,19 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(
*ctx_, "handle_init", function_);
BasicBlock* end_block = BasicBlock::Create(
*ctx_, "handle_init_end", function_);
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, halign);
llvm::Value* handle = builder_->CreateAlignedLoad(hptr, align);
llvm::Value* handle_not_null = builder_->CreateICmpNE(
handle, llvm::Constant::getNullValue(t_tvm_func_handle_));
builder_->CreateCondBr(
handle_not_null, end_block, init_block, md_very_likely_branch_);
// loaded handle, if created by call.
llvm::Value* loaded_handle = nullptr;
// Then block.
// We do not do lock here, so unlike static variable initialization
// This clause might be executed multiple times, but it is safe to do so.
// Initialize the handle if needed.
builder_->SetInsertPoint(init_block);
if (global) {
llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_);
llvm::Value* retcode = builder_->CreateCall(
f_tvm_func_get_global_, {GetConstString(fname), out});
init_block = CheckPackedCallSuccess(retcode);
loaded_handle = builder_->CreateAlignedLoad(out, halign);
} else {
LOG(FATAL) << "not yet supported";
}
llvm::Value* out = builder_->CreateAlloca(t_tvm_func_handle_);
llvm::Value* ctx = builder_->CreateLoad(gv_mod_ctx_);
llvm::Value* retcode = builder_->CreateCall(
f_tvm_get_func_from_env_, {ctx, GetConstString(fname), out});
init_block = CheckPackedCallSuccess(retcode);
llvm::Value* loaded_handle = builder_->CreateAlignedLoad(out, align);
builder_->CreateBr(end_block);
// end block
builder_->SetInsertPoint(end_block);
......@@ -793,11 +821,7 @@ llvm::Value* CodeGenLLVM::GetPackedFuncHandle(
llvm::Value* CodeGenLLVM::CreateCallPacked(const Call* op) {
CHECK_GE(op->args.size(), 1U);
std::string func_name = op->args[0].as<StringImm>()->value;
CHECK(!op->is_intrinsic(intrinsic::tvm_call_device))
<< "not implemented for now";
llvm::Value* handle = GetPackedFuncHandle(
func_name, op->is_intrinsic(intrinsic::tvm_call_global));
llvm::Value* handle = GetPackedFuncHandle(func_name);
// call the function
unsigned nargs = static_cast<unsigned>(op->args.size() - 1);
llvm::Value* targs = builder_->CreateAlloca(
......
......@@ -28,15 +28,23 @@ class CodeGenLLVM : public IRVisitor {
/*!
* \brief Initialize the code generator with given context
* \param module_name The name of the module.
* \param target_triple The target triple, can be empty.
* \param ctx The context.
*/
void Init(const std::string& module_name, llvm::LLVMContext* ctx);
void Init(const std::string& module_name,
const std::string& target_triple,
llvm::LLVMContext* ctx);
/*!
* \brief Compile and add function f to the current module.
* \param f The function to be added.
*/
void AddFunction(const LoweredFunc& f);
/*!
* \brief Add main function as the entry name
* \param entry_func_name The name of entry function to be added.
*/
void AddMainFunction(const std::string& entry_func_name);
/*!
* \brief Finish current pass of codegen, get the module.
* \return the created module.
*/
......@@ -119,6 +127,7 @@ class CodeGenLLVM : public IRVisitor {
std::unique_ptr<IRBuilder> builder_;
// The module to be returned;
std::unique_ptr<llvm::Module> module_;
std::unique_ptr<llvm::DataLayout> data_layout_;
// Internal metabuilder
std::unique_ptr<llvm::MDBuilder> md_builder_;
// llvm context
......@@ -145,7 +154,7 @@ class CodeGenLLVM : public IRVisitor {
llvm::StructType* t_tvm_value_{nullptr};
// tvm api functions
llvm::Function* f_tvm_func_call_{nullptr};
llvm::Function* f_tvm_func_get_global_{nullptr};
llvm::Function* f_tvm_get_func_from_env_{nullptr};
llvm::Function* f_tvm_api_set_last_error_{nullptr};
// The acting body
llvm::BasicBlock* block_{nullptr};
......@@ -166,17 +175,23 @@ class CodeGenLLVM : public IRVisitor {
llvm::Value* GetConstString(const std::string& str);
llvm::Value* CreateBufferPtr(Type t, llvm::Value* buffer, llvm::Value* index);
llvm::Value* CreateCast(Type from, Type to, llvm::Value* value);
llvm::Value* GetPackedFuncHandle(const std::string& str, bool global);
llvm::Value* GetPackedFuncHandle(const std::string& str);
// Check if the call to packed function is successful
// if not directly finalize function and pass on return code.
// return the end block after the check
llvm::BasicBlock* CheckPackedCallSuccess(llvm::Value* retcode);
// Initialize target
void InitTarget(const std::string& target);
// Add a function to set global module context
void InitGlobalContext();
// add alias information.
void AddAliasInfo(llvm::Instruction* load, const Variable* buffer, Expr index);
// The definition of local variable.
std::unordered_map<const Variable*, llvm::Value*> var_map_;
// global strings
std::unordered_map<std::string, llvm::Constant*> str_map_;
// The local module_context
llvm::GlobalVariable* gv_mod_ctx_{nullptr};
// global to packed function handle
std::unordered_map<std::string, llvm::GlobalVariable*> func_handle_map_;
};
......
......@@ -13,7 +13,7 @@ namespace codegen {
struct LLVMEnv {
std::mutex mu;
bool native_initialized{false};
bool all_initialized{false};
static LLVMEnv* Global() {
static LLVMEnv inst;
......@@ -23,17 +23,47 @@ struct LLVMEnv {
void InitializeLLVM() {
LLVMEnv* e = LLVMEnv::Global();
if (!e->native_initialized) {
if (!e->all_initialized) {
std::lock_guard<std::mutex>(e->mu);
if (!e->native_initialized) {
e->native_initialized = true;
llvm::InitializeNativeTarget();
llvm::InitializeNativeTargetAsmPrinter();
llvm::InitializeNativeTargetAsmParser();
if (!e->all_initialized) {
e->all_initialized = true;
llvm::InitializeAllTargetInfos();
llvm::InitializeAllTargets();
llvm::InitializeAllTargetMCs();
llvm::InitializeAllAsmParsers();
llvm::InitializeAllAsmPrinters();
}
}
}
std::pair<llvm::TargetMachine*, std::string>
LLVMGetTarget(const std::string& target_str) {
// setup target triple
std::string target_triple;
CHECK_EQ(target_str.substr(0, 4), "llvm");
if (target_str.length() > 4) {
target_triple = target_str.substr(5, target_str.length() - 5);
} else {
target_triple = "";
}
if (target_triple.length() == 0 ||
target_triple == "default") {
target_triple = llvm::sys::getDefaultTargetTriple();
}
std::string err;
const llvm::Target* target =
llvm::TargetRegistry::lookupTarget(target_triple, err);
CHECK(target) << err << " target_triple=" << target_triple;
std::string cpu = "generic";
std::string features = "";
llvm::TargetOptions opt;
auto rmodel = llvm::Reloc::PIC_;
llvm::TargetMachine* tm =
target->createTargetMachine(target_triple, cpu, features, opt, rmodel);
return {tm, target_triple};
}
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
......@@ -9,6 +9,8 @@
#include <llvm/ExecutionEngine/MCJIT.h>
#include <llvm/Bitcode/BitcodeWriter.h>
#include <llvm/IR/Value.h>
#include <llvm/IR/Argument.h>
#include <llvm/IR/BasicBlock.h>
......@@ -27,15 +29,16 @@
#include <llvm/Transforms/IPO/PassManagerBuilder.h>
#include <llvm/Transforms/IPO.h>
#include <llvm/Support/raw_ostream.h>
#include <llvm/Support/Casting.h>
#include <llvm/Support/TargetRegistry.h>
#include <llvm/Support/TargetSelect.h>
#include <llvm/Target/TargetMachine.h>
#include <llvm/Target/TargetOptions.h>
#include <utility>
#include <string>
extern "C" {
// Function signature for LLVM generated packed function.
typedef int (*LLVMPackedCFunc)(void* args,
int* type_codes,
int num_args);
} // extern "C"
namespace tvm {
namespace codegen {
......@@ -46,6 +49,14 @@ namespace codegen {
*/
void InitializeLLVM();
/*!
* \brief Get target machine from target_str string.
* \param target_str Target triple string, can have llvm- prefix, can be empty.
* \return Pair of target machine and target triple.
*/
std::pair<llvm::TargetMachine*, std::string>
LLVMGetTarget(const std::string& target_str);
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
......
/*!
* Copyright (c) 2017 by Contributors
* \file llvm_exec_engine.cc
*/
#include <tvm/runtime/packed_func.h>
#include <tvm/codegen.h>
#include "./llvm_common.h"
#include "./codegen_llvm.h"
namespace tvm {
namespace codegen {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
#ifdef TVM_LLVM_VERSION
// Environment to keep jit resources alive.
struct LLVMJITEnv {
std::shared_ptr<llvm::LLVMContext> ctx;
llvm::ExecutionEngine* ee{nullptr};
// constructor
LLVMJITEnv(std::shared_ptr<llvm::LLVMContext> ctx,
llvm::ExecutionEngine* ee)
: ctx(ctx), ee(ee) {
}
// destructor
~LLVMJITEnv() {
if (ee != nullptr) {
ee->runStaticConstructorsDestructors(true);
delete ee;
}
}
};
PackedFunc JITCompile(std::unique_ptr<llvm::Module> module,
std::shared_ptr<llvm::LLVMContext> ctx,
const std::string& func_name) {
llvm::EngineBuilder builder(std::move(module));
builder.setEngineKind(llvm::EngineKind::JIT);
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
std::shared_ptr<LLVMJITEnv> env = std::make_shared<LLVMJITEnv>(
ctx, builder.create());
CHECK(env->ee != nullptr);
auto* faddr = reinterpret_cast<LLVMPackedCFunc>(
env->ee->getFunctionAddress(func_name));
env->ee->runStaticConstructorsDestructors(false);
return PackedFunc([env, faddr](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
(void*)args.values, // NOLINT(*)
(int*)args.type_codes, // NOLINT(*)
args.num_args);
CHECK(ret == 0) << TVMGetLastError();
});
}
PackedFunc BuildLLVM(LoweredFunc func) {
InitializeLLVM();
// use one context per function.
std::shared_ptr<llvm::LLVMContext> ctx =
std::make_shared<llvm::LLVMContext>();
CodeGenLLVM cg;
cg.Init(func->name, ctx.get());
cg.AddFunction(func);
std::unique_ptr<llvm::Module> m = cg.Finish();
return JITCompile(std::move(m), ctx, func->name);
}
#else
PackedFunc BuildLLVM(LoweredFunc func) {
LOG(FATAL) << "LLVM is not enabled";
return PackedFunc();
}
#endif // TVM_LLVM_VERSION
} // namespace codegen
} // namespace tvm
/*!
* Copyright (c) 2017 by Contributors
* \file llvm_module.cc
* \brief LLVM runtime module for TVM
*/
#ifdef TVM_LLVM_VERSION
#include <tvm/runtime/packed_func.h>
#include <tvm/codegen.h>
#include "./llvm_common.h"
#include "./codegen_llvm.h"
#include "../../runtime/file_util.h"
#include "../../runtime/meta_data.h"
namespace tvm {
namespace codegen {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
class LLVMModuleNode : public runtime::ModuleNode {
public:
~LLVMModuleNode() {
module_.reset();
if (ee_ != nullptr) {
ee_->runStaticConstructorsDestructors(true);
delete ee_;
}
}
const char* type_key() const {
return "llvm";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
if (ee_ == nullptr) LazyInitJIT();
std::lock_guard<std::mutex> lock(mutex_);
BackendPackedCFunc faddr =
reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(name));
CHECK(faddr != nullptr)
<< "Failed to Precompile function " << name;
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
if (ee_ == nullptr) LazyInitJIT();
std::lock_guard<std::mutex> lock(mutex_);
BackendPackedCFunc faddr =
reinterpret_cast<BackendPackedCFunc>(ee_->getFunctionAddress(name));
if (faddr == nullptr) return PackedFunc();
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
(void*)args.values, // NOLINT(*)
(int*)args.type_codes, // NOLINT(*)
args.num_args);
CHECK_EQ(ret, 0) << TVMGetLastError();
});
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
std::string fmt = runtime::GetFileFormat(file_name, format);
std::error_code ecode;
llvm::raw_fd_ostream dest(file_name, ecode, llvm::sys::fs::F_None);
CHECK_EQ(ecode.value(), 0) << "Cannot open file: " << file_name
<< " " << ecode.message();
if (fmt == "o" || fmt == "obj") {
llvm::legacy::PassManager pass;
CHECK(tm_);
CHECK(tm_->addPassesToEmitFile(
pass, dest, llvm::TargetMachine::CGFT_ObjectFile) == 0)
<< "Cannot emit target CGFT_ObjectFile";
pass.run(*mptr_);
} else if (fmt == "ll") {
mptr_->print(dest, nullptr);
} else if (fmt == "bc") {
llvm::WriteBitcodeToFile(mptr_, dest);
} else {
LOG(FATAL) << "Do not know how to save file "
<< file_name << " with format=\'"<< format << "\'";
}
dest.close();
}
std::string GetSource(const std::string& format) final {
std::string type_str;
llvm::raw_string_ostream rso(type_str);
CHECK(mptr_ != nullptr);
mptr_->print(rso, nullptr);
return rso.str();
}
void Init(const Array<LoweredFunc>& funcs, std::string target) {
InitializeLLVM();
std::tie(tm_, target_triple_) = LLVMGetTarget(target);
CHECK_NE(funcs.size(), 0U);
ctx_ = std::make_shared<llvm::LLVMContext>();
CodeGenLLVM cg;
cg.Init(funcs[0]->name, target, ctx_.get());
for (LoweredFunc f : funcs) {
cg.AddFunction(f);
}
cg.AddMainFunction(funcs[0]->name);
module_ = cg.Finish();
mptr_ = module_.get();
}
private:
void LazyInitJIT() {
CHECK(ee_ == nullptr);
std::lock_guard<std::mutex> lock(mutex_);
std::string target_triple = mptr_->getTargetTriple();
llvm::EngineBuilder builder(std::move(module_));
builder.setEngineKind(llvm::EngineKind::JIT);
builder.setOptLevel(llvm::CodeGenOpt::Aggressive);
llvm::TargetMachine *tm = builder.selectTarget();
llvm::DataLayout layout(tm->createDataLayout());
CHECK(layout == mptr_->getDataLayout())
<< "Data layout mismatch between module("
<< mptr_->getDataLayout().getStringRepresentation() << ")"
<< " and ExecutionEngine ("
<< layout.getStringRepresentation() << ")";
ee_ = builder.create(tm);
CHECK(ee_ != nullptr)
<< "Failed to initialize git engine for " << target_triple;
ee_->runStaticConstructorsDestructors(false);
// setup context address.
void** ctx_addr =
reinterpret_cast<void**>(
ee_->getGlobalValueAddress(runtime::symbol::tvm_module_ctx));
if (ctx_addr != nullptr) {
*ctx_addr = this;
}
}
// The target configuration string
std::string target_triple_;
// JIT lock
std::mutex mutex_;
// execution engine
llvm::ExecutionEngine *ee_{nullptr};
// The raw pointer to the module.
llvm::Module* mptr_{nullptr};
// The target machine
llvm::TargetMachine* tm_{nullptr};
// The module, can be moved to ee if JIT is enabled.
std::unique_ptr<llvm::Module> module_;
// the context.
std::shared_ptr<llvm::LLVMContext> ctx_;
};
TVM_REGISTER_API(_codegen_build_llvm)
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<LLVMModuleNode> n = std::make_shared<LLVMModuleNode>();
n->Init(args[0], args[1]);
*rv = runtime::Module(n);
});
} // namespace codegen
} // namespace tvm
#endif // TVM_LLVM_VERSION
......@@ -2,6 +2,7 @@
* Copyright (c) 2017 by Contributors
* \file codegen_stack_vm.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <limits>
#include "./codegen_stack_vm.h"
......@@ -11,34 +12,16 @@ namespace codegen {
using namespace ir;
PackedFunc BuildStackVM(
LoweredFunc func,
const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs) {
StackVM vm = codegen::CodeGenStackVM().Compile(func, device_funcs);
auto f = [vm](TVMArgs args, TVMRetValue* rv) {
vm(args);
};
return PackedFunc(f);
}
CodeGenStackVM::FType& CodeGenStackVM::vtable() { // NOLINT(*)
static FType inst; return inst;
}
StackVM CodeGenStackVM::Compile(
LoweredFunc f,
const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs) {
StackVM CodeGenStackVM::Compile(LoweredFunc f) {
for (size_t i = 0; i < f->args.size(); ++i) {
Var v = f->args[i];
int vid = AllocVarID(v.get());
CHECK_EQ(static_cast<size_t>(vid), i);
}
// setup device function map
for (const auto& kv : device_funcs) {
int fid = static_cast<int>(vm_.packed_func.size());
vm_.packed_func.push_back(kv.second);
device_fun_idmap_[kv.first->name] = fid;
}
this->Push(f->body);
return std::move(vm_);
}
......@@ -194,7 +177,7 @@ void CodeGenStackVM::Push_(const ir::Call* op) {
case intrinsic::kTypeLanes: PushOp(StackVM::TVM_ARRAY_GET_TYPE_LANES); break;
default: LOG(FATAL) << "unknown field code";
}
} else if (op->is_intrinsic(intrinsic::tvm_call_global)) {
} else if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
CHECK_GE(op->args.size(), 1U);
const StringImm* s = op->args[0].as<StringImm>();
CHECK(s != nullptr) << "tvm_call_global expect first argument as function name";
......@@ -203,15 +186,14 @@ void CodeGenStackVM::Push_(const ir::Call* op) {
}
// find the fuction id.
const std::string& func_name = s->value;
auto it = global_fun_idmap_.find(func_name);
auto it = extern_fun_idmap_.find(func_name);
int fid;
if (it != global_fun_idmap_.end()) {
if (it != extern_fun_idmap_.end()) {
fid = it->second;
} else {
fid = static_cast<int>(vm_.packed_func.size());
PackedFunc f = PackedFunc::GetGlobal(func_name);
vm_.packed_func.push_back(f);
global_fun_idmap_[func_name] = fid;
fid = static_cast<int>(vm_.extern_func_name.size());
vm_.extern_func_name.push_back(func_name);
extern_fun_idmap_[func_name] = fid;
}
// get the argument type code.
std::vector<int> arg_type_codes;
......@@ -228,21 +210,6 @@ void CodeGenStackVM::Push_(const ir::Call* op) {
this->Push(op->args[0]);
this->PushOp(StackVM::PUSH_I64, 0);
this->PushOp(StackVM::EQ_I64);
} else if (op->is_intrinsic(intrinsic::tvm_call_device)) {
std::string func_name = op->args[0].as<StringImm>()->value;
auto it = device_fun_idmap_.find(func_name);
CHECK(it != device_fun_idmap_.end())
<< "Cannot find device function " << func_name;
const int fid = it->second;
std::vector<int> arg_type_codes;
for (size_t i = 1; i < op->args.size(); ++i) {
this->Push(op->args[i]);
Type t = op->args[i].type();
int lanes = t.lanes();
CHECK_EQ(lanes, 1);
arg_type_codes.push_back(t.code());
}
this->PushCallPacked(fid, arg_type_codes);
} else {
this->HandleUnknownCall(op);
}
......
......@@ -3,8 +3,8 @@
* \file codegen_stack_vm.h
* \brief Codegen into Simple Stack VM.
*/
#ifndef TVM_CODEGEN_CODEGEN_STACK_VM_H_
#define TVM_CODEGEN_CODEGEN_STACK_VM_H_
#ifndef TVM_CODEGEN_STACK_VM_CODEGEN_STACK_VM_H_
#define TVM_CODEGEN_STACK_VM_CODEGEN_STACK_VM_H_
#include <tvm/ir.h>
#include <tvm/lowered_func.h>
......@@ -13,13 +13,11 @@
#include <vector>
#include <unordered_map>
#include "../runtime/stack_vm/stack_vm.h"
#include "./stack_vm.h"
namespace tvm {
namespace codegen {
using runtime::StackVM;
/*!
* \brief A base class to generate a stack VM.
* This module is used to generate host wrapper
......@@ -34,9 +32,7 @@ class CodeGenStackVM {
* \note Only call compile once,
* create a new codegen object each time.
*/
StackVM Compile(
LoweredFunc f,
const std::unordered_map<LoweredFunc, PackedFunc>& device_funcs);
StackVM Compile(LoweredFunc f);
/*! \brief Push stmt to generate new code */
void Push(const Stmt& n);
/*! \brief Push expr to generate new code */
......@@ -108,11 +104,9 @@ class CodeGenStackVM {
/*! \brief id of each string */
std::unordered_map<std::string, int> str_idmap_;
/*! \brief id of each global function */
std::unordered_map<std::string, int> global_fun_idmap_;
/*! \brief id of device function */
std::unordered_map<std::string, int> device_fun_idmap_;
std::unordered_map<std::string, int> extern_fun_idmap_;
};
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_STACK_VM_H_
#endif // TVM_CODEGEN_STACK_VM_CODEGEN_STACK_VM_H_
......@@ -7,7 +7,7 @@
#include "./stack_vm.h"
namespace tvm {
namespace runtime {
namespace codegen {
typedef dmlc::ThreadLocalStore<StackVM::State> StackVMStateStore;
......@@ -182,6 +182,7 @@ void StackVM::operator()(const runtime::TVMArgs& args) const {
if (s->heap.size() < this->heap_size) {
s->heap.resize(this->heap_size);
}
s->heap[0].v_handle = (void*)args.values; // NOLINT(*)
s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*)
s->heap[2].v_int64 = args.num_args;
......@@ -193,7 +194,8 @@ void StackVM::Run(State* s) const {
int64_t pc = s->pc;
std::vector<TVMValue>& stack = s->stack;
std::vector<TVMValue>& heap = s->heap;
s->extern_func.clear();
s->extern_func.resize(extern_func_name.size());
if (stack.size() < stack_size) {
stack.resize(stack_size);
}
......@@ -287,7 +289,7 @@ void StackVM::Run(State* s) const {
alignof(Code) == alignof(int), "asusmption");
const int* type_codes = &(code[pc].v_int) + 3;
runtime::TVMRetValue rv;
packed_func[call_fid].CallPacked(
GetExtern(s, call_fid).CallPacked(
runtime::TVMArgs(&stack[sp + 1 - num_args], type_codes, num_args), &rv);
sp = sp + 1 - num_args;
stack[sp] = rv.value();
......@@ -364,5 +366,18 @@ void StackVM::Run(State* s) const {
}
}
} // namespace runtime
const PackedFunc& StackVM::GetExtern(State* s, int fid) const {
PackedFunc& f = s->extern_func[fid];
if (f == nullptr) {
CHECK(mod_ctx != nullptr)
<< "No local context is set in stackvm";
const PackedFunc* pf = mod_ctx->GetFuncFromEnv(extern_func_name[fid]);
CHECK(pf != nullptr);
f = *pf;
CHECK(f != nullptr);
}
return f;
}
} // namespace codegen
} // namespace tvm
......@@ -7,17 +7,20 @@
* to setup calls into device functions
* when only Runtime compilation for device is available(via NVRTC or OpenCL).
*/
#ifndef TVM_RUNTIME_STACK_VM_STACK_VM_H_
#define TVM_RUNTIME_STACK_VM_STACK_VM_H_
#ifndef TVM_CODEGEN_STACK_VM_STACK_VM_H_
#define TVM_CODEGEN_STACK_VM_STACK_VM_H_
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/packed_func_ext.h>
#include <string>
#include <vector>
namespace tvm {
namespace runtime {
namespace codegen {
using runtime::operator<<;
/*!
* \brief A simple stack-based virtual machine.
*/
......@@ -209,11 +212,19 @@ class StackVM {
std::vector<TVMValue> stack;
/*! \brief The global heap space */
std::vector<TVMValue> heap;
/*! \brief extern functions */
std::vector<PackedFunc> extern_func;
/*! \brief stack pointer */
int64_t sp{0};
/*! \brief program counter */
int64_t pc{0};
};
/*! \brief The external function entries. */
struct ExternFuncEntry {
std::string name;
runtime::PackedFunc func;
};
/*! \brief execute the stack vm with given state */
void Run(State* state) const;
/*!
......@@ -229,9 +240,11 @@ class StackVM {
std::vector<Code> code;
/*! \brief constant error messages */
std::vector<std::string> str_data;
/*! \brief Extern functions in packed func format */
std::vector<runtime::PackedFunc> packed_func;
/*! \brief name of each heap id*/
/*! \brief The current module context of stackvm */
runtime::ModuleNode* mod_ctx{nullptr};
/*! \brief Extern functions */
std::vector<std::string> extern_func_name;
/*! \brief name of each heap id */
std::vector<std::string> heap_id_name;
/*! \brief The memory size needed */
size_t heap_size{0};
......@@ -296,8 +309,12 @@ class StackVM {
return ADDR_LOAD_FP64;
}
friend std::ostream& operator<<(std::ostream& os, const StackVM& vm); // NOLINT(*)
private:
// get extern function.
const PackedFunc& GetExtern(State* s, int fid) const;
};
} // namespace runtime
} // namespace codegen
} // namespace tvm
#endif // TVM_RUNTIME_STACK_VM_STACK_VM_H_
#endif // TVM_CODEGEN_STACK_VM_STACK_VM_H_
/*!
* Copyright (c) 2017 by Contributors
* \file stack_vm_module.cc
*/
#include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h>
#include <tvm/codegen.h>
#include "./codegen_stack_vm.h"
namespace tvm {
namespace codegen {
class StackVMModuleNode : public runtime::ModuleNode {
public:
const char* type_key() const {
return "stackvm";
}
void PreCompile(const std::string& name, TVMContext ctx) final {}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
if (name == runtime::symbol::tvm_module_main) {
return GetFunction(entry_func_, sptr_to_self);
}
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
const StackVM& vm = it->second;
// capture sptr_to_self to keep module node alive.
return PackedFunc([vm, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
vm(args);
});
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "StackVM do not support SaveToFile";
}
std::string GetSource(const std::string& format) final {
std::ostringstream os;
for (const auto& kv : fmap_) {
os << "Function: " << kv.first << '\n';
os << kv.second;
}
return os.str();
}
static runtime::Module Build(const Array<LoweredFunc>& funcs) {
CHECK_NE(funcs.size(), 0U);
std::shared_ptr<StackVMModuleNode> n =
std::make_shared<StackVMModuleNode>();
for (LoweredFunc f : funcs) {
StackVM vm = codegen::CodeGenStackVM().Compile(f);
CHECK(!n->fmap_.count(f->name))
<< "Function name " << f->name << "already exist in list";
vm.mod_ctx = n.get();
n->fmap_[f->name] = std::move(vm);
}
n->entry_func_ = funcs[0]->name;
return runtime::Module(n);
}
private:
// entry function.
std::string entry_func_;
// internal function map
std::unordered_map<std::string, StackVM> fmap_;
};
TVM_REGISTER_API(_codegen_build_stackvm)
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = StackVMModuleNode::Build(args[0]);
});
} // namespace codegen
} // namespace tvm
......@@ -35,9 +35,12 @@ inline Stmt MakeAssertEQ(Expr lhs, Expr rhs, std::string msg) {
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
int num_packed_args) {
int num_unpacked_args) {
const Type tvm_index_type = UInt(32);
const Stmt nop = Evaluate::make(0);
int num_args = static_cast<int>(api_args.size());
CHECK_LE(num_unpacked_args, num_args);
int num_packed_args = num_args - num_unpacked_args;
// Data field definitions
// The packed fields
Var v_packed_args("args", Handle());
......@@ -182,6 +185,7 @@ LoweredFunc MakeAPI(Stmt body,
n->name = name;
n->args = args;
n->handle_data_type = handle_data_type;
n->is_packed_func = num_unpacked_args == 0;
n->body = MergeNest({seq_init, seq_check}, body);
LoweredFunc f(n);
Array<Var> undefined = UndefinedVars(f);
......
......@@ -7,6 +7,7 @@
#include <tvm/lowered_func.h>
#include <tvm/ir_pass.h>
#include <tvm/ir_mutator.h>
#include <tvm/runtime/module.h>
#include <unordered_map>
namespace tvm {
......@@ -154,6 +155,19 @@ class HostDeviceSplitter : public IRMutator {
std::make_shared<LoweredFuncNode>(*f.operator->());
n->body = this->Mutate(f->body);
if (f->is_packed_func && device_funcs_.size() != 0) {
// insert auto set device from device function.
Array<Expr> args = {StringImm::make(runtime::symbol::tvm_entry_setdevice)};
for (Var arg : f->args) {
args.push_back(arg);
}
n->body = Block::make(
Evaluate::make(Call::make(
Int(32), intrinsic::tvm_call_packed,
args, Call::Intrinsic)),
n->body);
}
Array<LoweredFunc> ret{LoweredFunc(n)};
for (LoweredFunc x : device_funcs_) {
ret.push_back(x);
......@@ -194,7 +208,7 @@ class HostDeviceSplitter : public IRMutator {
}
device_funcs_.emplace_back(f_device);
return Evaluate::make(Call::make(
Int(32), intrinsic::tvm_call_device,
Int(32), intrinsic::tvm_call_packed,
call_args, Call::Intrinsic));
}
......
......@@ -6,6 +6,7 @@
#include <dmlc/thread_local.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <algorithm>
#include <string>
#include "./runtime_base.h"
......@@ -69,10 +70,136 @@ using namespace tvm::runtime;
struct TVMRuntimeEntry {
std::string ret_str;
std::string last_error;
};
typedef dmlc::ThreadLocalStore<TVMRuntimeEntry> TVMAPIRuntimeStore;
const char *TVMGetLastError() {
return TVMAPIRuntimeStore::Get()->last_error.c_str();
}
void TVMAPISetLastError(const char* msg) {
TVMAPIRuntimeStore::Get()->last_error = msg;
}
int TVMModLoadFromFile(const char* file_name,
const char* format,
TVMModuleHandle* out) {
API_BEGIN();
Module m = Module::LoadFromFile(file_name, format);
*out = new Module(m);
API_END();
}
int TVMModImport(TVMModuleHandle mod,
TVMModuleHandle dep) {
API_BEGIN();
static_cast<Module*>(mod)->Import(
*static_cast<Module*>(dep));
API_END();
}
int TVMModGetFunction(TVMModuleHandle mod,
const char* func_name,
int query_imports,
TVMFunctionHandle *func) {
API_BEGIN();
PackedFunc pf = static_cast<Module*>(mod)->GetFunction(
func_name, query_imports);
if (pf != nullptr) {
*func = new PackedFunc(pf);
} else {
*func = nullptr;
}
API_END();
}
int TVMModPreCompile(TVMModuleHandle mod,
const char* func_name,
TVMContext ctx) {
API_BEGIN();
(*static_cast<Module*>(mod))->PreCompile(func_name, ctx);
API_END();
}
int TVMBackendGetFuncFromEnv(void* mod_node,
const char* func_name,
TVMFunctionHandle *func) {
API_BEGIN();
*func = (TVMFunctionHandle)(
static_cast<ModuleNode*>(mod_node)->GetFuncFromEnv(func_name));
API_END();
}
int TVMModFree(TVMModuleHandle mod) {
API_BEGIN();
delete static_cast<Module*>(mod);
API_END();
}
int TVMFuncFree(TVMFunctionHandle func) {
API_BEGIN();
delete static_cast<PackedFunc*>(func);
API_END();
}
int TVMFuncCall(TVMFunctionHandle func,
TVMValue* args,
int* arg_type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code) {
API_BEGIN();
TVMRetValue rv;
(*static_cast<const PackedFunc*>(func)).CallPacked(
TVMArgs(args, arg_type_codes, num_args), &rv);
// handle return string.
if (rv.type_code() == kStr ||
rv.type_code() == kTVMType) {
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
e->ret_str = rv.operator std::string();
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
} else {
rv.MoveToCHost(ret_val, ret_type_code);
}
API_END();
}
int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue value,
int type_code) {
API_BEGIN();
TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
*rv = TVMArgValue(value, type_code);
API_END();
}
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out) {
API_BEGIN();
if (fin == nullptr) {
*out = new PackedFunc(
[func, resource_handle](TVMArgs args, TVMRetValue* rv) {
func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle);
});
} else {
// wrap it in a shared_ptr, with fin as deleter.
// so fin will be called when the lambda went out of scope.
std::shared_ptr<void> rpack(resource_handle, fin);
*out = new PackedFunc(
[func, rpack](TVMArgs args, TVMRetValue* rv) {
func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, rpack.get());
});
}
API_END();
}
int TVMDeviceInit(int dev_mask,
const char** option_keys,
const char** option_vals,
......@@ -175,65 +302,3 @@ int TVMSynchronize(TVMContext ctx, TVMStreamHandle stream) {
});
API_END();
}
int TVMFuncFree(TVMFunctionHandle func) {
API_BEGIN();
delete static_cast<PackedFunc*>(func);
API_END();
}
int TVMFuncCall(TVMFunctionHandle func,
TVMValue* args,
int* arg_type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code) {
API_BEGIN();
TVMRetValue rv;
(*static_cast<const PackedFunc*>(func)).CallPacked(
TVMArgs(args, arg_type_codes, num_args), &rv);
// handle return string.
if (rv.type_code() == kStr ||
rv.type_code() == kTVMType) {
TVMRuntimeEntry* e = TVMAPIRuntimeStore::Get();
e->ret_str = rv.operator std::string();
*ret_type_code = kStr;
ret_val->v_str = e->ret_str.c_str();
} else {
rv.MoveToCHost(ret_val, ret_type_code);
}
API_END();
}
int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue value,
int type_code) {
API_BEGIN();
TVMRetValue* rv = static_cast<TVMRetValue*>(ret);
*rv = TVMArgValue(value, type_code);
API_END();
}
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out) {
API_BEGIN();
if (fin == nullptr) {
*out = new PackedFunc(
[func, resource_handle](TVMArgs args, TVMRetValue* rv) {
func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, resource_handle);
});
} else {
// wrap it in a shared_ptr, with fin as deleter.
// so fin will be called when the lambda went out of scope.
std::shared_ptr<void> rpack(resource_handle, fin);
*out = new PackedFunc(
[func, rpack](TVMArgs args, TVMRetValue* rv) {
func((TVMValue*)args.values, (int*)args.type_codes, // NOLINT(*)
args.num_args, rv, rpack.get());
});
}
API_END();
}
......@@ -41,29 +41,6 @@ namespace runtime {
* \return The PTX code.
*/
std::string NVRTCCompile(const std::string& code);
/*!
* \brief Automatically detect and set cuda device.
* \param args The arguments.
*/
inline void AutoSetCUDADevice(const TVMArgs& args) {
int dev_id = -1;
for (int i = 0; i < args.size(); ++i) {
if (args.type_codes[i] == kArrayHandle) {
TVMContext ctx = static_cast<TVMArray*>(
args.values[i].v_handle)->ctx;
CHECK_EQ(ctx.dev_mask, kGPU)
<< "All operands need to be GPU";
if (dev_id == -1) {
dev_id = ctx.dev_id;
} else {
CHECK_EQ(dev_id, ctx.dev_id)
<< "Operands comes from different devices ";
}
}
}
CUDA_CALL(cudaSetDevice(dev_id));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_CUDA_RUNTIME
......
......@@ -5,6 +5,7 @@
#include "./cuda_module.h"
#if TVM_CUDA_RUNTIME
#include <tvm/runtime/registry.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>
......@@ -14,20 +15,64 @@
#include "./cuda_common.h"
#include "../void_addr_args.h"
#include "../thread_storage_scope.h"
#include "../meta_data.h"
#include "../file_util.h"
namespace tvm {
namespace runtime {
/*!
* \brief Internal data structure to support multi-gpu execution.
* Try to use CUDA runtime's primary context.
*/
class CUDAModule::Internal {
// Module to support thread-safe multi-GPU execution.
// cuModule is a per-GPU module
// The runtime will contain a per-device module table
// The modules will be lazily loaded
class CUDAModuleNode : public runtime::ModuleNode {
public:
explicit Internal(std::string data)
: data_(data) {
explicit CUDAModuleNode(std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string cuda_source)
: data_(data), fmt_(fmt), fmap_(fmap), cuda_source_(cuda_source) {
std::fill(module_.begin(), module_.end(), nullptr);
}
// destructor
~CUDAModuleNode() {
for (size_t i = 0; i < module_.size(); ++i) {
if (module_[i] != nullptr) {
CUDA_CALL(cudaSetDevice(i));
CUDA_DRIVER_CALL(cuModuleUnload(module_[i]));
}
}
}
const char* type_key() const final {
return "cuda";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
CUDA_CALL(cudaSetDevice(ctx.dev_id));
cudaFree(nullptr);
this->GetFunc(ctx.dev_id, name);
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "Not implemented";
}
std::string GetSource(const std::string& format) final {
if (format == fmt_) return data_;
if (cuda_source_.length() != 0) {
return cuda_source_;
} else {
if (fmt_ == "ptx") return data_;
return "";
}
}
// get a CUfunction from primary context in dev_id
CUfunction GetFunc(int dev_id, const std::string& func_name) {
std::lock_guard<std::mutex> lock(mutex_);
......@@ -46,21 +91,18 @@ class CUDAModule::Internal {
}
return func;
}
// destructor
~Internal() {
for (size_t i = 0; i < module_.size(); ++i) {
if (module_[i] != nullptr) {
CUDA_CALL(cudaSetDevice(i));
CUDA_DRIVER_CALL(cuModuleUnload(module_[i]));
}
}
}
private:
// the binary data
std::string data_;
// The format
std::string fmt_;
// function information table.
std::unordered_map<std::string, FunctionInfo> fmap_;
// The cuda source.
std::string cuda_source_;
// the internal modules per GPU, to be lazily initialized.
std::array<CUmodule, CUDAModule::kMaxNumGPUs> module_;
std::array<CUmodule, kMaxNumGPUs> module_;
// internal mutex when updating the module
std::mutex mutex_;
};
......@@ -69,11 +111,13 @@ class CUDAModule::Internal {
class CUDAWrappedFunc {
public:
// initialize the CUDA function.
void Init(std::shared_ptr<CUDAModule::Internal> m,
void Init(CUDAModuleNode* m,
std::shared_ptr<ModuleNode> sptr,
const std::string& func_name,
size_t num_void_args,
const std::vector<std::string>& thread_axis_tags) {
const std::vector<std::string>& thread_axis_tags) {
m_ = m;
sptr_ = sptr;
func_name_ = func_name;
std::fill(fcache_.begin(), fcache_.end(), nullptr);
thread_axis_cfg_.Init(num_void_args, thread_axis_tags);
......@@ -101,32 +145,87 @@ class CUDAWrappedFunc {
private:
// internal module
std::shared_ptr<CUDAModule::Internal> m_;
CUDAModuleNode* m_;
// the resource holder
std::shared_ptr<ModuleNode> sptr_;
// The name of the function.
std::string func_name_;
// Device function cache per device.
// mark as mutable, to enable lazy initialization
mutable std::array<CUfunction, CUDAModule::kMaxNumGPUs> fcache_;
mutable std::array<CUfunction, kMaxNumGPUs> fcache_;
// thread axis configuration
ThreadAxisConfig thread_axis_cfg_;
};
PackedFunc CUDAModule::GetPackedFunc(
const std::string& func_name,
const std::vector<TVMType> arg_types,
const std::vector<std::string> thread_axis_tags) const {
void AutoSetCUDADevice(const TVMArgs& args, TVMRetValue* rv) {
CHECK_EQ(args.size(), 3);
TVMValue* values = static_cast<TVMValue*>(args[0].operator void*());
int* type_codes = static_cast<int*>(args[1].operator void*());
int num_args = args[2].operator int();
int dev_id = -1;
for (int i = 0; i < num_args; ++i) {
if (type_codes[i] == kArrayHandle) {
TVMContext ctx = static_cast<TVMArray*>(values[i].v_handle)->ctx;
CHECK_EQ(ctx.dev_mask, kGPU)
<< "All operands need to be GPU";
if (dev_id == -1) {
dev_id = ctx.dev_id;
} else {
CHECK_EQ(dev_id, ctx.dev_id)
<< "Operands comes from different devices ";
}
}
}
CHECK_NE(dev_id, -1)
<< "Cannot detect device id from list";
CUDA_CALL(cudaSetDevice(dev_id));
}
PackedFunc CUDAModuleNode::GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
CHECK_EQ(sptr_to_self.get(), this);
CHECK_NE(name, symbol::tvm_module_main)
<< "Device function do not have main";
if (name == symbol::tvm_entry_setdevice) {
return PackedFunc(AutoSetCUDADevice);
}
auto it = fmap_.find(name);
if (it == fmap_.end()) return PackedFunc();
const FunctionInfo& info = it->second;
CUDAWrappedFunc f;
f.Init(ptr_, func_name, arg_types.size(), thread_axis_tags);
return PackFromVoidAddrArgs(f, arg_types);
f.Init(this, sptr_to_self, name, info.arg_types.size(), info.thread_axis_tags);
return PackFromVoidAddrArgs(f, info.arg_types);
}
CUDAModule CUDAModule::Create(std::string ptx) {
// call a runtime API to make sure the context is created.
CUDAModule m;
m.ptr_ = std::make_shared<Internal>(ptx);
return m;
Module CUDAModuleCreate(
std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string cuda_source) {
std::shared_ptr<CUDAModuleNode> n =
std::make_shared<CUDAModuleNode>(data, fmt, fmap, cuda_source);
return Module(n);
}
// Load module from module.
Module CUDAModuleLoad(const std::string& file_name,
const std::string& format) {
std::string fmt = GetFileFormat(file_name, format);
std::string data = LoadBinaryFile(file_name);
return CUDAModuleCreate(data, fmt, {{}}, std::string());
}
TVM_REGISTER_GLOBAL(_module_loadfile_cubin)
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CUDAModuleLoad(args[0], args[1]);
});
TVM_REGISTER_GLOBAL(_module_loadfile_ptx)
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CUDAModuleLoad(args[0], args[1]);
});
} // namespace runtime
} // namespace tvm
#endif // TVM_CUDA_RUNTIME
......@@ -7,44 +7,31 @@
#define TVM_RUNTIME_CUDA_CUDA_MODULE_H_
#include <tvm/runtime/config.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <memory>
#include <vector>
#include <string>
#include "../meta_data.h"
namespace tvm {
namespace runtime {
/*! \brief Maximum number of GPU supported in CUDAModule */
static constexpr const int kMaxNumGPUs = 32;
/*!
* \brief Handle execution of CUDA kernels as PackedFunc.
* It wraps around driver API to work with CUDA runtime API.
* \brief create a cuda module from data.
*
* \param data The module data, can be ptx, cubin
* \param fmt The format of the data, can be "ptx", "cubin"
* \param fmap The map function information map of each function.
* \param cuda_source Optional, cuda source file
*/
class CUDAModule {
public:
/*!
* \brief Get CUDA Kernel launch wrapped as PackedFunc
* \param func_name The name of the function.
* \param arg_types The type of each argument in the function.
* \param thread_axis_tags The tag sequence of the thread axis.
*/
PackedFunc GetPackedFunc(
const std::string& func_name,
const std::vector<TVMType> arg_types,
const std::vector<std::string> thread_axis_tags) const;
/*!
* \brief create a cuda module from data.
* \param data The module data.
*/
static CUDAModule Create(std::string data);
/*! \brief hidden internal data structure. */
class Internal;
/*! \brief Maximum number of GPU supported in CUDAModule */
static constexpr const int kMaxNumGPUs = 32;
private:
std::shared_ptr<Internal> ptr_;
};
Module CUDAModuleCreate(
std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string cuda_source);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_CUDA_CUDA_MODULE_H_
/*!
* Copyright (c) 2017 by Contributors
* \file dso_module.cc
* \brief Module to load from dynamic shared library.
*/
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include "./meta_data.h"
#if defined(_WIN32)
#include <windows.h>
#else
#include <dlfcn.h>
#endif
namespace tvm {
namespace runtime {
// Module to load from dynamic shared libary.
// This is the default module TVM used for hostside AOT
class DSOModuleNode : public ModuleNode {
public:
~DSOModuleNode() {
if (lib_handle_) Unload();
}
const char* type_key() const {
return "dso";
}
void PreCompile(const std::string& name, TVMContext ctx) final {
GetFuncPtr(name);
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
BackendPackedCFunc faddr = GetFuncPtr(name);
if (faddr == nullptr) return PackedFunc();
return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) {
int ret = (*faddr)(
(void*)args.values, // NOLINT(*)
(int*)args.type_codes, // NOLINT(*)
args.num_args);
CHECK_EQ(ret, 0) << TVMGetLastError();
});
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
LOG(FATAL) << "Cannot save dso to another file";
}
std::string GetSource(const std::string& format) final {
return "";
}
void Init(const std::string& name) {
Load(name);
CHECK(lib_handle_ != nullptr)
<< "Failed to load dynamic shared library " << name;
void** ctx_addr =
reinterpret_cast<void**>(
GetGlobalVPtr(runtime::symbol::tvm_module_ctx));
if (ctx_addr != nullptr) {
*ctx_addr = this;
}
}
private:
// Platform dependent handling.
#if defined(_WIN32)
// library handle
HMODULE lib_handle_{nullptr};
// Load the library
void Load(const std::string& name) {
lib_handle_ = LoadLibrary(name.c_str());
}
BackendPackedCFunc GetFuncPtr(const std::string& name) {
return reinterpret_cast<BackendPackedCFunc>(
GetProcAddress(lib_handle_, name.c_str())); // NOLINT(*)
}
void* GetGlobalVPtr(const std::string& name) {
return reinterpret_cast<void*>(
GetProcAddress(lib_handle_, name.c_str())); // NOLINT(*)
}
void Unload() {
FreeLibrary(lib_handle_);
}
#else
// Library handle
void* lib_handle_{nullptr};
// load the library
void Load(const std::string& name) {
lib_handle_ = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL);
}
BackendPackedCFunc GetFuncPtr(const std::string& name) {
return reinterpret_cast<BackendPackedCFunc>(
dlsym(lib_handle_, name.c_str()));
}
void* GetGlobalVPtr(const std::string& name) {
return dlsym(lib_handle_, name.c_str());
}
void Unload() {
dlclose(lib_handle_);
}
#endif
};
TVM_REGISTER_GLOBAL(_module_loadfile_so)
.set_body([](TVMArgs args, TVMRetValue* rv) {
std::shared_ptr<DSOModuleNode> n = std::make_shared<DSOModuleNode>();
n->Init(args[0]);
*rv = runtime::Module(n);
});
} // namespace runtime
} // namespace tvm
/*!
* Copyright (c) 2016 by Contributors
* Implementation of error handling API
* \file error_handle.cc
*/
#include <dmlc/thread_local.h>
#include <string>
#include "./runtime_base.h"
struct TVMErrorEntry {
std::string last_error;
};
typedef dmlc::ThreadLocalStore<TVMErrorEntry> TVMAPIErrorStore;
const char *TVMGetLastError() {
return TVMAPIErrorStore::Get()->last_error.c_str();
}
void TVMAPISetLastError(const char* msg) {
TVMAPIErrorStore::Get()->last_error = msg;
}
/*!
* Copyright (c) 2017 by Contributors
* \file file_util.h
* \brief Minimum file manipulation util for runtime.
*/
#ifndef TVM_RUNTIME_FILE_UTIL_H_
#define TVM_RUNTIME_FILE_UTIL_H_
#include <dmlc/logging.h>
#include <fstream>
#include <string>
namespace tvm {
namespace runtime {
/*!
* \brief Get file format from given file name or format argument.
* \param file_name The name of the file.
* \param format The format of the file.
*/
inline std::string GetFileFormat(const std::string& file_name,
const std::string& format) {
std::string fmt = format;
if (fmt.length() == 0) {
size_t pos = file_name.find_last_of(".");
if (pos != std::string::npos) {
return file_name.substr(pos + 1, file_name.length() - pos - 1);
} else {
return "";
}
} else {
return format;
}
}
/*!
* \brief Load binary file into a in-memory buffer.
* \param file_name The name of the file.
*/
inline std::string LoadBinaryFile(const std::string& file_name) {
std::ifstream fs(file_name, std::ios::in | std::ios::binary);
CHECK(!fs.fail())
<< "Cannot open " << file_name;
// get its size:
fs.seekg(0, std::ios::end);
size_t size = fs.tellg();
fs.seekg(0, std::ios::beg);
std::string data;
data.resize(size);
fs.read(&data[0], size);
return data;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_FILE_UTIL_H_
/*!
* Copyright (c) 2017 by Contributors
* \file meta_data.h
* \brief Meta data related utilities
*/
#ifndef TVM_RUNTIME_META_DATA_H_
#define TVM_RUNTIME_META_DATA_H_
#include <dmlc/json.h>
#include <string>
#include <vector>
#include "./runtime_base.h"
extern "C" {
// Function signature for generated packed function in shared library
typedef int (*BackendPackedCFunc)(void* args,
int* type_codes,
int num_args);
} // extern "C"
namespace tvm {
namespace runtime {
/*! \brief function information needed by device */
struct FunctionInfo {
std::string name;
std::vector<TVMType> arg_types;
std::vector<std::string> thread_axis_tags;
void Save(dmlc::JSONWriter *writer) const {
std::vector<std::string> sarg_types(arg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
sarg_types[i] = TVMType2String(arg_types[i]);
}
writer->BeginObject();
writer->WriteObjectKeyValue("name", name);
writer->WriteObjectKeyValue("arg_types", sarg_types);
writer->WriteObjectKeyValue("thread_axis_tags", thread_axis_tags);
writer->EndObject();
}
void Load(dmlc::JSONReader *reader) {
dmlc::JSONObjectReadHelper helper;
std::vector<std::string> sarg_types;
helper.DeclareField("name", &name);
helper.DeclareField("arg_types", &sarg_types);
helper.DeclareField("thread_axis_tags", &thread_axis_tags);
helper.ReadAllFields(reader);
arg_types.resize(sarg_types.size());
for (size_t i = 0; i < arg_types.size(); ++i) {
arg_types[i] = String2TVMType(sarg_types[i]);
}
}
};
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_META_DATA_H_
/*!
* Copyright (c) 2017 by Contributors
* \file module.cc
* \brief The global registry of packed function.
*/
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
#include <unordered_set>
#include "./file_util.h"
#include "./meta_data.h"
namespace tvm {
namespace runtime {
PackedFunc Module::GetFunction(
const std::string& name, bool query_imports) {
PackedFunc pf = node_->GetFunction(name, node_);
if (pf != nullptr) return pf;
if (query_imports) {
for (const Module& m : node_->imports_) {
pf = m.node_->GetFunction(name, m.node_);
if (pf != nullptr) return pf;
}
}
return pf;
}
void Module::Import(Module other) {
// cyclic detection.
std::unordered_set<const ModuleNode*> visited{other.node_.get()};
std::vector<const ModuleNode*> stack{other.node_.get()};
while (!stack.empty()) {
const ModuleNode* n = stack.back();
stack.pop_back();
for (const Module& m : n->imports_) {
const ModuleNode* next = m.node_.get();
if (visited.count(next)) continue;
visited.insert(next);
stack.push_back(next);
}
}
CHECK(!visited.count(node_.get()))
<< "Cyclic dependency detected during import";
node_->imports_.emplace_back(std::move(other));
}
Module Module::LoadFromFile(const std::string& file_name,
const std::string& format) {
std::string fmt = GetFileFormat(file_name, format);
CHECK(fmt.length() != 0)
<< "Cannot deduce format of file " << file_name;
if (fmt == "dll" || fmt == "dylib" || fmt == "dso") {
fmt = "so";
}
std::string load_f_name = "_module_loadfile_" + fmt;
const PackedFunc* f = Registry::Get(load_f_name);
CHECK(f != nullptr)
<< "Loader of " << format << "("
<< load_f_name << ") is not presented.";
Module m = (*f)(file_name, format);
return m;
}
const PackedFunc* ModuleNode::GetFuncFromEnv(const std::string& name) {
auto it = import_cache_.find(name);
if (it != import_cache_.end()) return it->second.get();
PackedFunc pf;
for (Module& m : this->imports_) {
pf = m.GetFunction(name, false);
if (pf != nullptr) break;
}
if (pf == nullptr) {
const PackedFunc* f = Registry::Get(name);
CHECK(f != nullptr)
<< "Cannot find function " << name
<< " in the imported modules or global registry";
return f;
} else {
std::unique_ptr<PackedFunc> f(new PackedFunc(pf));
import_cache_[name] = std::move(f);
return import_cache_.at(name).get();
}
}
TVM_REGISTER_GLOBAL(_module__GetSource)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator Module()->GetSource(args[1]);
});
TVM_REGISTER_GLOBAL(_module__ImportsSize)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = static_cast<int64_t>(
args[0].operator Module()->imports().size());
});
TVM_REGISTER_GLOBAL(_module__GetImport)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator Module()->
imports().at(args[1].operator int());
});
TVM_REGISTER_GLOBAL(_module__GetTyeKey)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator Module()->type_key();
});
TVM_REGISTER_GLOBAL(_module__LoadFromFile)
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = Module::LoadFromFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL(_module__SaveToFile)
.set_body([](TVMArgs args, TVMRetValue *ret) {
args[0].operator Module()->
SaveToFile(args[1], args[2]);
});
} // namespace runtime
} // namespace tvm
......@@ -114,6 +114,10 @@ class OpenCLWorkspace {
// Number of registered kernels
// Used to register kernel into the workspace.
size_t num_registered_kernels{0};
// The version counter, used
size_t timestamp{0};
// Ids that are freed by kernels.
std::vector<size_t> free_kernel_ids;
// the mutex for initialization
std::mutex mu;
// destructor
......@@ -139,13 +143,21 @@ class OpenCLWorkspace {
static OpenCLWorkspace* Global();
};
/*! \brief Thread local workspace */
class OpenCLThreadEntry {
public:
// The kernel entry and version.
struct KTEntry {
// The kernel handle.
cl_kernel kernel{nullptr};
// timestamp used to recognize stale kernel
size_t version{0};
};
/*! \brief The current context */
TVMContext context;
/*! \brief The thread-local kernel table */
std::vector<cl_kernel> kernel_table;
std::vector<KTEntry> kernel_table;
OpenCLThreadEntry() {
context.dev_id = 0;
......@@ -155,29 +167,6 @@ class OpenCLThreadEntry {
static OpenCLThreadEntry* ThreadLocal();
};
} // namespace cl
/*!
* \brief Automatically detect and set cuda device.
* \param args The arguments.
*/
inline void AutoSetOpenCLContext(const TVMArgs& args) {
// TODO(tqchen): merge this with CUDA logic.
int dev_id = -1;
for (int i = 0; i < args.size(); ++i) {
if (args.type_codes[i] == kArrayHandle) {
TVMContext ctx = static_cast<TVMArray*>(
args.values[i].v_handle)->ctx;
CHECK_EQ(ctx.dev_mask, kOpenCL)
<< "All operands need to be GPU";
if (dev_id == -1) {
dev_id = ctx.dev_id;
} else {
CHECK_EQ(dev_id, ctx.dev_id)
<< "Operands comes from different devices ";
}
}
}
cl::OpenCLThreadEntry::ThreadLocal()->context.dev_id = dev_id;
}
} // namespace runtime
} // namespace tvm
#endif // TVM_OPENCL_RUNTIME
......
......@@ -11,38 +11,21 @@
#include <memory>
#include <vector>
#include <string>
#include "../meta_data.h"
namespace tvm {
namespace runtime {
/*!
* \brief Handle execution of OPENCL kernels as PackedFunc.
* It wraps around driver API to work with OPENCL runtime API.
* \brief create a cuda module from data.
*
* \param data The module data, can be ptx, cubin
* \param fmt The format of the data, can be "clbin", "cl"
* \param fmap The map function information map of each function.
*/
class OpenCLModule {
public:
/*!
* \brief Get OpenCL Kernel launch wrapped as PackedFunc
* \param func_name The name of the function.
* \param arg_types The type of each argument in the function.
* \param thread_axis_tags The tag sequence of the thread axis.
*/
PackedFunc GetPackedFunc(
const std::string& func_name,
const std::vector<TVMType> arg_types,
const std::vector<std::string> thread_axis_tags) const;
/*!
* \brief create a OpenCL module from data.
* \param source The module data.
*/
static OpenCLModule CreateWithSource(std::string source);
/*! \brief hidden internal data structure. */
class Internal;
private:
std::shared_ptr<Internal> ptr_;
};
Module OpenCLModuleCreate(
std::string data,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap);
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_
/*!
* Copyright (c) 2017 by Contributors
* \file packed_func_registry.cc
* \file registry.cc
* \brief The global registry of packed function.
*/
#include <dmlc/logging.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <unordered_map>
#include <mutex>
#include <memory>
#include "./runtime_base.h"
namespace tvm {
namespace runtime {
struct PackedFuncRegistry {
struct Registry::Manager {
// map storing the functions.
// We delibrately used raw pointer
// This is because PackedFunc can contain callbacks into the host languge(python)
// and the resource can become invalid because of indeterminstic order of destruction.
// The resources will only be recycled during program exit.
std::unordered_map<std::string, PackedFunc*> fmap;
std::unordered_map<std::string, Registry*> fmap;
std::mutex mutex;
static PackedFuncRegistry* Global() {
static PackedFuncRegistry inst;
static Manager* Global() {
static Manager inst;
return &inst;
}
};
const PackedFunc& PackedFunc::RegisterGlobal(
const std::string& name, PackedFunc f) {
PackedFuncRegistry* r = PackedFuncRegistry::Global();
auto it = r->fmap.find(name);
CHECK(it == r->fmap.end())
Registry& Registry::set_body(PackedFunc f) { // NOLINT(*)
func_ = f;
return *this;
}
Registry& Registry::Register(const std::string& name) { // NOLINT(*)
Manager* m = Manager::Global();
std::lock_guard<std::mutex>(m->mutex);
auto it = m->fmap.find(name);
CHECK(it == m->fmap.end())
<< "Global PackedFunc " << name << " is already registered";
PackedFunc* fp = new PackedFunc(f);
r->fmap[name] = fp;
return *fp;
Registry* r = new Registry();
r->name_ = name;
m->fmap[name] = r;
return *r;
}
const PackedFunc& PackedFunc::GetGlobal(const std::string& name) {
PackedFuncRegistry* r = PackedFuncRegistry::Global();
auto it = r->fmap.find(name);
CHECK(it != r->fmap.end())
<< "Global PackedFunc " << name << " is not registered";
return *(it->second);
bool Registry::Remove(const std::string& name) {
Manager* m = Manager::Global();
std::lock_guard<std::mutex>(m->mutex);
auto it = m->fmap.find(name);
if (it == m->fmap.end()) return false;
m->fmap.erase(it);
return true;
}
bool PackedFunc::GlobalExist(const std::string& name) {
PackedFuncRegistry* r = PackedFuncRegistry::Global();
auto it = r->fmap.find(name);
return it != r->fmap.end();
const PackedFunc* Registry::Get(const std::string& name) {
Manager* m = Manager::Global();
std::lock_guard<std::mutex>(m->mutex);
auto it = m->fmap.find(name);
if (it == m->fmap.end()) return nullptr;
return &(it->second->func_);
}
std::vector<std::string> PackedFunc::ListGlobalNames() {
PackedFuncRegistry* r = PackedFuncRegistry::Global();
std::vector<std::string> Registry::ListNames() {
Manager* m = Manager::Global();
std::lock_guard<std::mutex>(m->mutex);
std::vector<std::string> keys;
keys.reserve(r->fmap.size());
for (const auto &kv : r->fmap) {
keys.reserve(m->fmap.size());
for (const auto &kv : m->fmap) {
keys.push_back(kv.first);
}
return keys;
......@@ -78,26 +90,27 @@ typedef dmlc::ThreadLocalStore<TVMFuncThreadLocalEntry> TVMFuncThreadLocalStore;
int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f) {
using tvm::runtime::PackedFunc;
API_BEGIN();
PackedFunc::RegisterGlobal(name, *static_cast<PackedFunc*>(f));
tvm::runtime::Registry::Register(name)
.set_body(*static_cast<tvm::runtime::PackedFunc*>(f));
API_END();
}
int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
using tvm::runtime::PackedFunc;
API_BEGIN();
const PackedFunc& f = PackedFunc::GetGlobal(name);
*out = (TVMFunctionHandle)(&f); // NOLINT(*)
const tvm::runtime::PackedFunc* fp =
tvm::runtime::Registry::Get(name);
CHECK(fp != nullptr)
<< "Cannot find global function " << name;
*out = (TVMFunctionHandle)(fp); // NOLINT(*)
API_END();
}
int TVMFuncListGlobalNames(int *out_size,
const char*** out_array) {
using tvm::runtime::PackedFunc;
API_BEGIN();
TVMFuncThreadLocalEntry *ret = TVMFuncThreadLocalStore::Get();
ret->ret_vec_str = PackedFunc::ListGlobalNames();
ret->ret_vec_str = tvm::runtime::Registry::ListNames();
ret->ret_vec_charp.clear();
for (size_t i = 0; i < ret->ret_vec_str.size(); ++i) {
ret->ret_vec_charp.push_back(ret->ret_vec_str[i].c_str());
......
......@@ -19,20 +19,15 @@ def test_add():
s[C].vectorize(x)
# one line to build the function.
def check_device(target):
codes = []
fadd = tvm.build(s, [A, B, C],
target, record_codes=codes,
name="myadd")
if target == "cuda":
ctx = tvm.gpu(0)
else:
ctx = tvm.cl(0)
if not ctx.enabled:
def check_device(device, host="stackvm"):
if not tvm.codegen.target_enabled(host):
return
for c in codes[1:]:
print(c)
if not tvm.codegen.target_enabled(device):
return
fadd = tvm.build(s, [A, B, C],
device, host,
name="myadd")
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
# launch the kernel.
n = 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
......@@ -43,7 +38,7 @@ def test_add():
c.asnumpy(), a.asnumpy() + b.asnumpy())
tvm.init_opencl()
check_device("cuda")
check_device("cuda", "llvm")
check_device("opencl")
......
......@@ -52,16 +52,16 @@ def test_gemm():
# lowering test
s.normalize()
def check_device(target):
codes = []
f = tvm.build(s, [A, B, C], target, record_codes=codes,
max_auto_unroll_step=max_auto_unroll_step)
if target == "cuda":
ctx = tvm.gpu(0)
else:
ctx = tvm.cl(0)
if not ctx.enabled:
# one line to build the function.
def check_device(device, host="stackvm"):
if not tvm.codegen.target_enabled(host):
return
if not tvm.codegen.target_enabled(device):
return
f = tvm.build(s, [A, B, C], device, host,
max_auto_unroll_step=max_auto_unroll_step)
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
# launch the kernel.
n = nn
m = n
......
......@@ -17,29 +17,29 @@ def test_sum():
_, x = s[B].split(B.op.axis[0], factor=num_thread, outer=block_x)
_, x = s[B].split(x, outer=thread_x)
tvm.init_opencl()
codes = []
fsum = tvm.build(s,
args=[A, B],
target="opencl", name="myadd",
record_codes=codes)
for c in codes:
print(c)
num_device = 1
for i in range(num_device):
ctx = tvm.opencl(i)
if not ctx.enabled:
continue
# one line to build the function.
def check_device(device, host="stackvm"):
if not tvm.codegen.target_enabled(host):
return
if not tvm.codegen.target_enabled(device):
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
fsum = tvm.build(s,
args=[A, B],
target=device, target_host=host,
name="mysum")
# launch the kernel.
n = 1028
m = 129
#a = tvm.nd.array(np.zeros((n, m)).astype(A.dtype), ctx)
a = tvm.nd.array(np.random.uniform(size=(n, m)).astype(A.dtype), ctx)
b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
fsum(a, b)
np.testing.assert_allclose(
b.asnumpy(), np.sum(a.asnumpy(), axis=1), rtol=1e-4)
tvm.init_opencl()
check_device("cuda")
check_device("opencl")
if __name__ == "__main__":
test_sum()
......@@ -22,20 +22,15 @@ def test_scan():
_, x = s[s_update].split(x, outer=thread_x)
# one line to build the function.
def check_device(target):
codes = []
def check_device(device, host="stackvm"):
if not tvm.codegen.target_enabled(host):
return
if not tvm.codegen.target_enabled(device):
return
fscan = tvm.build(s, [X, res],
target, record_codes=codes,
device, host,
name="myscan")
if target == "cuda":
ctx = tvm.gpu(0)
else:
ctx = tvm.cl(0)
if not ctx.enabled:
return
for c in codes[1:]:
print(c)
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
# launch the kernel.
n = 1024
m = 10
......@@ -48,6 +43,7 @@ def test_scan():
tvm.init_opencl()
check_device("cuda")
check_device("opencl")
if __name__ == "__main__":
......
......@@ -23,59 +23,33 @@ def test_add_pipeline():
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 3)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
fsplits = tvm.ir_pass.SplitHostDevice(fapi)
def check_cuda():
output_ssa = False
for f in fsplits[1:]:
print(tvm.codegen.CompileToC(f, output_ssa, "cuda"))
def check_target(device, host="stackvm"):
if not tvm.codegen.target_enabled(host):
return
if not tvm.codegen.target_enabled(device):
return
ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
mhost = tvm.codegen.build(fsplits[0], host)
mdev = tvm.codegen.build(fsplits[1:], device)
mhost.import_module(mdev)
code = mdev.get_source()
f = mhost.entry_func
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
f(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
check_target("cuda", host="stackvm")
check_target("cuda", host="llvm")
# build and invoke the kernel.
fcuda = tvm.codegen.BuildNVRTC(fsplits, "stackvm")
num_device = 1
for i in range(num_device):
ctx = tvm.gpu(i)
if not ctx.enabled:
continue
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
fcuda(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
def check_opencl():
output_ssa = False
for f in fsplits[1:]:
print(tvm.codegen.CompileToC(f, output_ssa, "opencl"))
# build and invoke the kernel.
fcl = tvm.codegen.BuildOpenCL(fsplits, "stackvm")
# Disable OpenCL runtime test for now,
# since the local worksize on CPU might be too large.
num_device = 0
for i in range(num_device):
ctx = tvm.cl(i)
if not ctx.enabled:
continue
# launch the kernel.
n = 1027
a = tvm.nd.array(np.random.uniform(size=n).astype(Ab.dtype), ctx)
b = tvm.nd.array(np.random.uniform(size=n).astype(Bb.dtype), ctx)
c = tvm.nd.array(np.zeros(n, dtype=Cb.dtype), ctx)
fcl(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
tvm.init_opencl()
if tvm.cl(0).enabled:
check_opencl()
if tvm.gpu(0).enabled:
check_cuda()
if __name__ == "__main__":
test_add_pipeline()
import tvm
import numpy as np
def tvm_call_global(*args):
def tvm_call_packed(*args):
args = tvm.convert(args)
return tvm.make.Call("int32", "tvm_call_global", args, 4, None, 0)
return tvm.make.Call("int32", "tvm_call_packed", args, 4, None, 0)
def run_jit(fapi, check):
for target in ["stackvm"]:
if target == "llvm":
f = tvm.codegen.BuildLLVM(fapi)
else:
f = tvm.codegen.BuildStackVM(fapi)
for target in ["llvm", "stackvm"]:
if not tvm.codegen.target_enabled(target):
continue
f = tvm.codegen.build(fapi, target)
s = f.get_source()
check(f)
def test_stack_vm_basic():
a = tvm.nd.array(np.zeros(10, dtype='float32'))
@tvm.register_func
......@@ -25,8 +24,8 @@ def test_stack_vm_basic():
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), tvm.float32)
stmt = tvm.make.Evaluate(tvm_call_global("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 1)
stmt = tvm.make.Evaluate(tvm_call_packed("tvm_call_back_get_shape", Ab.shape[0]))
fapi = tvm.ir_pass.MakeAPI(stmt, "print_shape", [Ab], 0)
run_jit(fapi, lambda f: f(a))
......@@ -47,9 +46,8 @@ def test_stack_vm_loop():
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
i + 1),
tvm.make.Evaluate(tvm_call_global("tvm_stack_vm_print", i))))
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 1)
f = tvm.codegen.BuildStackVM(fapi)
tvm.make.Evaluate(tvm_call_packed("tvm_stack_vm_print", i))))
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
def check(f):
f(a)
......@@ -71,7 +69,7 @@ def test_stack_vm_cond():
tvm.make.Load(dtype, Ab.data, i) + 1, i + 1),
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 2, i + 1)))
fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 1)
fapi = tvm.ir_pass.MakeAPI(stmt, "test", [Ab], 0)
def check(f):
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f(a)
......@@ -94,11 +92,13 @@ def test_llvm_add_pipeline():
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
stmt = tvm.ir_pass.Simplify(stmt)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 3)
fapi = tvm.ir_pass.MakeAPI(stmt, "myadd", [Ab, Bb, Cb], 0)
def check_llvm():
if not tvm.codegen.target_enabled("llvm"):
return
# build and invoke the kernel.
f = tvm.codegen.BuildLLVM(fapi)
f = tvm.codegen.build(fapi, "llvm")
ctx = tvm.cpu(0)
# launch the kernel.
n = 1027
......@@ -108,10 +108,10 @@ def test_llvm_add_pipeline():
f(a, b, c)
np.testing.assert_allclose(
c.asnumpy(), a.asnumpy() + b.asnumpy())
#check_llvm()
check_llvm()
if __name__ == "__main__":
test_stack_vm_cond()
test_stack_vm_basic()
test_stack_vm_cond()
test_stack_vm_loop()
test_llvm_add_pipeline()
import tvm
from tvm.addon import cc_compiler as cc
import os
import tempfile
import numpy as np
def test_dso_module_load():
if not tvm.codegen.target_enabled("llvm"):
return
dtype = 'int64'
temp_dir = tempfile.mkdtemp()
def save_object(names):
n = tvm.Var('n')
Ab = tvm.Buffer((n, ), dtype)
i = tvm.Var('i')
# for i in 0 to n-1:
stmt = tvm.make.For(
i, 0, n - 1, 0, 0,
tvm.make.Store(Ab.data,
tvm.make.Load(dtype, Ab.data, i) + 1,
i + 1))
fapi = tvm.ir_pass.MakeAPI(stmt, "ramp", [Ab], 0)
m = tvm.codegen.build(fapi, "llvm")
for name in names:
m.save(name)
path_obj = "%s/test.o" % temp_dir
path_ll = "%s/test.ll" % temp_dir
path_bc = "%s/test.bc" % temp_dir
path_dso = "%s/test.so" % temp_dir
save_object([path_obj, path_ll, path_bc])
cc.create_shared(path_dso, [path_obj])
f1 = tvm.module.load(path_dso)
f2 = tvm.module.load(path_dso)
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f1(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
a = tvm.nd.array(np.zeros(10, dtype=dtype))
f2(a)
np.testing.assert_equal(a.asnumpy(), np.arange(a.shape[0]))
files = [path_obj, path_ll, path_bc, path_dso]
for f in files:
os.remove(f)
os.rmdir(temp_dir)
def test_cuda_module_load():
pass
if __name__ == "__main__":
test_dso_module_load()
......@@ -17,8 +17,9 @@ def test_makeapi():
Cb = tvm.Buffer(C.shape, C.dtype, name='C')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B:Bb, C:Cb})
num_packed_args = 2
f = tvm.ir_pass.MakeAPI(stmt, "myadd", [n, Ab, Bb, Cb], num_packed_args)
num_unpacked_args = 2
f = tvm.ir_pass.MakeAPI(
stmt, "myadd", [n, Ab, Bb, Cb], num_unpacked_args)
assert(f.handle_data_type[Ab.data].dtype == Ab.dtype)
assert(len(f.args) == 5)
output_ssa = False
......
......@@ -21,7 +21,7 @@ def test_storage_sync():
Ab = tvm.Buffer(A.shape, A.dtype, name='A')
A2b = tvm.Buffer(A2.shape, A2.dtype, name='A2')
stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, A2: A2b})
f = tvm.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 2)
f = tvm.ir_pass.MakeAPI(stmt, "test", [Ab, A2b], 0)
flist = tvm.ir_pass.SplitHostDevice(f)
f = flist[1]
f = tvm.ir_pass.StorageSync(f, "shared")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment