Commit c89cd59a by Tianqi Chen Committed by GitHub

[MODULE/DSO] Support pack everything into one shared library. (#133)

* [MODULE/DSO] Support pack everything into one shared library.

* fix osx load
parent bf8a5c07
Subproject commit 2b75a0ce6f191ad0fcb5319039b41e990968542a Subproject commit a6c5701219e635fea808d264aefc5b03c3aec314
...@@ -31,6 +31,16 @@ using runtime::TVMRetValue; ...@@ -31,6 +31,16 @@ using runtime::TVMRetValue;
*/ */
runtime::Module Build(const Array<LoweredFunc>& funcs, runtime::Module Build(const Array<LoweredFunc>& funcs,
const std::string& target); const std::string& target);
/*!
* \brief Pack imported device library to a C file.
* Compile the C file and link with the host library
* will allow the DSO loader to automatically discover and import
* the dependency from the shared library.
*
* \param m The host module with the imports.
* \return cstr The C string representation of the file.
*/
std::string PackImportsToC(const runtime::Module& m);
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#ifndef TVM_RUNTIME_MODULE_H_ #ifndef TVM_RUNTIME_MODULE_H_
#define TVM_RUNTIME_MODULE_H_ #define TVM_RUNTIME_MODULE_H_
#include <dmlc/io.h>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <string> #include <string>
...@@ -58,6 +59,8 @@ class Module { ...@@ -58,6 +59,8 @@ class Module {
const std::string& format); const std::string& format);
/*! \return internal container */ /*! \return internal container */
inline ModuleNode* operator->(); inline ModuleNode* operator->();
/*! \return internal container */
inline const ModuleNode* operator->() const;
private: private:
std::shared_ptr<ModuleNode> node_; std::shared_ptr<ModuleNode> node_;
...@@ -112,13 +115,20 @@ class ModuleNode { ...@@ -112,13 +115,20 @@ class ModuleNode {
virtual void SaveToFile(const std::string& file_name, virtual void SaveToFile(const std::string& file_name,
const std::string& format) = 0; const std::string& format) = 0;
/*! /*!
* \brief Save the module to binary stream.
* \param stream The binary stream to save to.
* \note It is recommended to implement this for device modules,
* but not necessarily host modules.
* We can use this to do AOT loading of bundled device functions.
*/
virtual void SaveToBinary(dmlc::Stream* stream) = 0;
/*!
* \brief Get the source code of module, when available. * \brief Get the source code of module, when available.
* \param format Format of the source code, can be empty by default. * \param format Format of the source code, can be empty by default.
* \return Possible source code when available. * \return Possible source code when available.
*/ */
virtual std::string GetSource( virtual std::string GetSource(
const std::string& format = "") = 0; const std::string& format = "") = 0;
/*! /*!
* \brief Get a function from current environment * \brief Get a function from current environment
* The environment includes all the imports as well as Global functions. * The environment includes all the imports as well as Global functions.
...@@ -132,10 +142,12 @@ class ModuleNode { ...@@ -132,10 +142,12 @@ class ModuleNode {
return imports_; return imports_;
} }
private: protected:
friend class Module; friend class Module;
/*! \brief The modules this module depend on */ /*! \brief The modules this module depend on */
std::vector<Module> imports_; std::vector<Module> imports_;
private:
/*! \brief Cache used by GetImport */ /*! \brief Cache used by GetImport */
std::unordered_map<std::string, std::unordered_map<std::string,
std::unique_ptr<PackedFunc> > import_cache_; std::unique_ptr<PackedFunc> > import_cache_;
...@@ -145,6 +157,10 @@ class ModuleNode { ...@@ -145,6 +157,10 @@ class ModuleNode {
namespace symbol { namespace symbol {
/*! \brief Global variable to store module context. */ /*! \brief Global variable to store module context. */
constexpr const char* tvm_module_ctx = "__tvm_module_ctx"; constexpr const char* tvm_module_ctx = "__tvm_module_ctx";
/*! \brief Global variable to store device module blob */
constexpr const char* tvm_dev_mblob = "__tvm_dev_mblob";
/*! \brief Number of bytes of device module blob. */
constexpr const char* tvm_dev_mblob_nbytes = "__tvm_dev_mblob_nbytes";
/*! \brief global function to set device */ /*! \brief global function to set device */
constexpr const char* tvm_set_device = "__tvm_set_device"; constexpr const char* tvm_set_device = "__tvm_set_device";
/*! \brief Auxiliary counter to global barrier. */ /*! \brief Auxiliary counter to global barrier. */
...@@ -160,6 +176,10 @@ inline ModuleNode* Module::operator->() { ...@@ -160,6 +176,10 @@ inline ModuleNode* Module::operator->() {
return node_.get(); return node_.get();
} }
// Const overload: gives read-only access to the node held in node_.
inline const ModuleNode* Module::operator->() const {
return node_.get();
}
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
from __future__ import absolute_import as _abs from __future__ import absolute_import as _abs
from ._ffi.function import ModuleBase, _set_class_module from ._ffi.function import ModuleBase, _set_class_module
from ._ffi.function import _init_api from ._ffi.function import _init_api
from .contrib import cc_compiler as _cc, util as _util
class Module(ModuleBase): class Module(ModuleBase):
"""Module container of all TVM generated functions""" """Module container of all TVM generated functions"""
...@@ -44,15 +44,46 @@ class Module(ModuleBase): ...@@ -44,15 +44,46 @@ class Module(ModuleBase):
def save(self, file_name, fmt=""): def save(self, file_name, fmt=""):
"""Save the module to file. """Save the module to file.
This does not save the dependent device modules.
See also export_library.
Parameters Parameters
---------- ----------
file_name : str file_name : str
The name of the file. The name of the file.
fmt : str fmt : str
The format of the file. The format of the file.
See Also
--------
Module.export_library : export the module to shared library.
""" """
_SaveToFile(self, file_name, fmt) _SaveToFile(self, file_name, fmt)
def export_library(self, file_name):
    """Export this module and its imported device modules as one shared library.

    Only modules whose ``type_key`` is ``"llvm"`` are accepted. The module
    itself is saved as an object file in a temporary directory; when there
    are imported modules, they are packed into a generated C source file
    that is compiled and linked together with that object file.

    Parameters
    ----------
    file_name : str
        The name of the shared library.

    Raises
    ------
    ValueError
        If ``type_key`` of this module is not ``"llvm"``.
    """
    if self.type_key != "llvm":
        raise ValueError("Only llvm support export shared")
    workdir = _util.tempdir()
    host_obj = workdir.relpath("lib.o")
    self.save(host_obj)
    link_inputs = [host_obj]
    if self.imported_modules:
        # Imported modules travel as a packed blob in a generated source file.
        pack_src = workdir.relpath("devc.cc")
        with open(pack_src, "w") as out:
            out.write(_PackImportsToC(self))
        link_inputs.append(pack_src)
    _cc.create_shared(file_name, link_inputs)
def load(path, fmt=""): def load(path, fmt=""):
"""Load module from file """Load module from file
......
...@@ -20,5 +20,10 @@ TVM_REGISTER_API("codegen._Build") ...@@ -20,5 +20,10 @@ TVM_REGISTER_API("codegen._Build")
*ret = Build(args[0], args[1]); *ret = Build(args[0], args[1]);
} }
}); });
// Registers PackImportsToC under the API name "module._PackImportsToC".
// args[0] is passed as the module argument and the returned C source
// string is stored into *ret.
TVM_REGISTER_API("module._PackImportsToC")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = PackImportsToC(args[0]);
});
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -7,6 +7,9 @@ ...@@ -7,6 +7,9 @@
#include <tvm/ir_pass.h> #include <tvm/ir_pass.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <dmlc/memory_io.h>
#include <sstream>
#include <iostream>
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -32,5 +35,50 @@ runtime::Module Build(const Array<LoweredFunc>& funcs, ...@@ -32,5 +35,50 @@ runtime::Module Build(const Array<LoweredFunc>& funcs,
return m; return m;
} }
/*!
 * \brief Serialize the imports of a host module into the text of a C file.
 *
 * Blob layout: a uint64_t import count, then for every import its type key
 * string followed by whatever that import's SaveToBinary writes.
 * The generated file defines, with C linkage, a byte array named by
 * runtime::symbol::tvm_dev_mblob and an unsigned long named by
 * runtime::symbol::tvm_dev_mblob_nbytes holding the array length.
 *
 * \param mod The host module whose imports are packed. Every import must
 *            itself have no imports (checked with CHECK_EQ).
 * \return The text of the generated C file.
 */
std::string PackImportsToC(const runtime::Module& mod) {
  std::string bin;
  dmlc::MemoryStringStream ms(&bin);
  dmlc::Stream* stream = &ms;
  uint64_t sz = static_cast<uint64_t>(mod->imports().size());
  stream->Write(sz);
  // Copy each handle on purpose: SaveToBinary is a non-const member, so it
  // cannot be reached through a const Module (const operator-> yields a
  // const ModuleNode*).
  for (runtime::Module im : mod->imports()) {
    CHECK_EQ(im->imports().size(), 0U)
        << "Only support simply one-level hierachy";
    std::string tkey = im->type_key();
    stream->Write(tkey);
    im->SaveToBinary(stream);
  }
  // Translate the blob to a C program.
  // The array is declared as unsigned char: the initializers range up to
  // 0xff, which would be a narrowing conversion for a signed char when the
  // file is compiled as C++11 or later. The loader looks the symbol up by
  // name only, so the element type change does not affect it.
  std::ostringstream os;
  os << "#ifdef __cplusplus\n"
     << "extern \"C\" {\n"
     << "#endif\n";
  os << "extern const unsigned char " << runtime::symbol::tvm_dev_mblob << "[];\n";
  os << "extern const unsigned long " << runtime::symbol::tvm_dev_mblob_nbytes << ";\n";
  os << "const unsigned char " << runtime::symbol::tvm_dev_mblob
     << "[" << bin.length() << "] = {\n  ";
  os << std::hex;
  // 20 byte literals per output line.
  size_t nunit = 80 / 4;
  for (size_t i = 0; i < bin.length(); ++i) {
    // separators between consecutive byte literals
    if (i != 0) {
      if (i % nunit == 0) {
        os << ",\n  ";
      } else {
        os << ",";
      }
    }
    int c = bin[i];
    // mask so that a negative char value prints as a two-digit byte
    os << "0x" << (c & 0xff);
  }
  os << "\n};\n"
     << "const unsigned long " << runtime::symbol::tvm_dev_mblob_nbytes
     << " = " << std::dec << bin.length() << "UL;\n"
     << "#ifdef __cplusplus\n"
     << "}\n"
     << "#endif\n";
  return os.str();
}
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
...@@ -19,7 +19,7 @@ using runtime::TVMArgs; ...@@ -19,7 +19,7 @@ using runtime::TVMArgs;
using runtime::TVMRetValue; using runtime::TVMRetValue;
using runtime::PackedFunc; using runtime::PackedFunc;
class LLVMModuleNode : public runtime::ModuleNode { class LLVMModuleNode final : public runtime::ModuleNode {
public: public:
~LLVMModuleNode() { ~LLVMModuleNode() {
module_.reset(); module_.reset();
...@@ -84,6 +84,10 @@ class LLVMModuleNode : public runtime::ModuleNode { ...@@ -84,6 +84,10 @@ class LLVMModuleNode : public runtime::ModuleNode {
dest.close(); dest.close();
} }
// Binary serialization is not provided for this module type: the call
// only reports a fatal error and never writes to `stream`.
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "LLVMModule: SaveToBinary not supported";
}
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
std::string type_str; std::string type_str;
llvm::raw_string_ostream rso(type_str); llvm::raw_string_ostream rso(type_str);
......
...@@ -13,7 +13,7 @@ using runtime::TVMArgs; ...@@ -13,7 +13,7 @@ using runtime::TVMArgs;
using runtime::TVMRetValue; using runtime::TVMRetValue;
using runtime::PackedFunc; using runtime::PackedFunc;
// Simulator function // Simulator function
class SourceModuleNode : public runtime::ModuleNode { class SourceModuleNode final : public runtime::ModuleNode {
public: public:
SourceModuleNode(std::string code, SourceModuleNode(std::string code,
std::string fmt) std::string fmt)
...@@ -30,10 +30,16 @@ class SourceModuleNode : public runtime::ModuleNode { ...@@ -30,10 +30,16 @@ class SourceModuleNode : public runtime::ModuleNode {
<< " build TVM with \'" << fmt_ << "\' runtime support"; << " build TVM with \'" << fmt_ << "\' runtime support";
return PackedFunc(); return PackedFunc();
} }
void SaveToFile(const std::string& file_name, void SaveToFile(const std::string& file_name,
const std::string& format) final { const std::string& format) final {
LOG(FATAL) << "not implemented"; LOG(FATAL) << "SourceModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "SourceModule: SaveToBinary not supported";
} }
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
return code_; return code_;
} }
......
...@@ -35,7 +35,11 @@ class StackVMModuleNode : public runtime::ModuleNode { ...@@ -35,7 +35,11 @@ class StackVMModuleNode : public runtime::ModuleNode {
void SaveToFile(const std::string& file_name, void SaveToFile(const std::string& file_name,
const std::string& format) final { const std::string& format) final {
LOG(FATAL) << "StackVM do not support SaveToFile"; LOG(FATAL) << "StackVMModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "StackVMModule: SaveToBinary not supported";
} }
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
......
...@@ -59,7 +59,11 @@ class VerilogModuleNode : public runtime::ModuleNode { ...@@ -59,7 +59,11 @@ class VerilogModuleNode : public runtime::ModuleNode {
void SaveToFile(const std::string& file_name, void SaveToFile(const std::string& file_name,
const std::string& format) final { const std::string& format) final {
LOG(FATAL) << "not implemented"; LOG(FATAL) << "VerilogModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "VerilogModule: SaveToBinary not supported";
} }
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
......
...@@ -69,6 +69,12 @@ class CUDAModuleNode : public runtime::ModuleNode { ...@@ -69,6 +69,12 @@ class CUDAModuleNode : public runtime::ModuleNode {
SaveBinaryToFile(file_name, data_); SaveBinaryToFile(file_name, data_);
} }
// Write this module to `stream` as three consecutive fields:
// fmt_, fmap_, data_. CUDAModuleLoadBinary reads them back in
// exactly this order, so the two must be kept in sync.
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(data_);
}
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
if (format == fmt_) return data_; if (format == fmt_) return data_;
if (cuda_source_.length() != 0) { if (cuda_source_.length() != 0) {
...@@ -242,8 +248,8 @@ Module CUDAModuleCreate( ...@@ -242,8 +248,8 @@ Module CUDAModuleCreate(
} }
// Load module from module. // Load module from module.
Module CUDAModuleLoad(const std::string& file_name, Module CUDAModuleLoadFile(const std::string& file_name,
const std::string& format) { const std::string& format) {
std::string data; std::string data;
std::unordered_map<std::string, FunctionInfo> fmap; std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format); std::string fmt = GetFileFormat(file_name, format);
...@@ -253,14 +259,30 @@ Module CUDAModuleLoad(const std::string& file_name, ...@@ -253,14 +259,30 @@ Module CUDAModuleLoad(const std::string& file_name,
return CUDAModuleCreate(data, fmt, fmap, std::string()); return CUDAModuleCreate(data, fmt, fmap, std::string());
} }
/*!
 * \brief Rebuild a CUDA module from a binary stream.
 *
 * Reads, in order, the format string, the function-info map and the module
 * data, i.e. the sequence written by CUDAModuleNode::SaveToBinary.
 *
 * \param strm Type-erased pointer; it is cast to dmlc::Stream*.
 * \return The result of CUDAModuleCreate, called with an empty string as
 *         its last argument.
 */
Module CUDAModuleLoadBinary(void* strm) {
  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
  std::string data;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt;
  // Fail loudly on a truncated or corrupt blob instead of creating a module
  // from partially read fields.
  CHECK(stream->Read(&fmt)) << "CUDAModule: cannot read format from binary";
  CHECK(stream->Read(&fmap)) << "CUDAModule: cannot read function map from binary";
  CHECK(stream->Read(&data)) << "CUDAModule: cannot read data from binary";
  return CUDAModuleCreate(data, fmt, fmap, std::string());
}
TVM_REGISTER_GLOBAL("module.loadfile_cubin") TVM_REGISTER_GLOBAL("module.loadfile_cubin")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CUDAModuleLoad(args[0], args[1]); *rv = CUDAModuleLoadFile(args[0], args[1]);
}); });
TVM_REGISTER_GLOBAL("module.loadfile_ptx") TVM_REGISTER_GLOBAL("module.loadfile_ptx")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CUDAModuleLoad(args[0], args[1]); *rv = CUDAModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadbinary_cuda")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = CUDAModuleLoadBinary(args[0]);
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
* \file dso_module.cc * \file dso_module.cc
* \brief Module to load from dynamic shared library. * \brief Module to load from dynamic shared library.
*/ */
#include <dmlc/memory_io.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
...@@ -19,7 +20,7 @@ namespace runtime { ...@@ -19,7 +20,7 @@ namespace runtime {
// Module to load from dynamic shared libary. // Module to load from dynamic shared libary.
// This is the default module TVM used for hostside AOT // This is the default module TVM used for hostside AOT
class DSOModuleNode : public ModuleNode { class DSOModuleNode final : public ModuleNode {
public: public:
~DSOModuleNode() { ~DSOModuleNode() {
if (lib_handle_) Unload(); if (lib_handle_) Unload();
...@@ -49,7 +50,11 @@ class DSOModuleNode : public ModuleNode { ...@@ -49,7 +50,11 @@ class DSOModuleNode : public ModuleNode {
void SaveToFile(const std::string& file_name, void SaveToFile(const std::string& file_name,
const std::string& format) final { const std::string& format) final {
LOG(FATAL) << "Cannot save dso to another file"; LOG(FATAL) << "DSOModule: SaveToFile not supported";
}
void SaveToBinary(dmlc::Stream* stream) final {
LOG(FATAL) << "DSOModule: SaveToBinary not supported";
} }
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
...@@ -66,6 +71,33 @@ class DSOModuleNode : public ModuleNode { ...@@ -66,6 +71,33 @@ class DSOModuleNode : public ModuleNode {
if (ctx_addr != nullptr) { if (ctx_addr != nullptr) {
*ctx_addr = this; *ctx_addr = this;
} }
// Load the imported modules
const char* dev_mblob =
reinterpret_cast<const char*>(
GetGlobalVPtr(runtime::symbol::tvm_dev_mblob));
const unsigned long* dev_mblob_nbytes = // NOLINT(*)
reinterpret_cast<const unsigned long*>( // NOLINT(*)
GetGlobalVPtr(runtime::symbol::tvm_dev_mblob_nbytes));
if (dev_mblob != nullptr) {
CHECK(dev_mblob_nbytes != nullptr);
dmlc::MemoryFixedSizeStream fs(
(void*)dev_mblob, dev_mblob_nbytes[0]); // NOLINT(*)
dmlc::Stream* stream = &fs;
uint64_t size;
CHECK(stream->Read(&size));
for (uint64_t i = 0; i < size; ++i) {
std::string tkey;
CHECK(stream->Read(&tkey));
std::string fkey = "module.loadbinary_" + tkey;
const PackedFunc* f = Registry::Get(fkey);
CHECK(f != nullptr)
<< "Loader of " << tkey << "("
<< fkey << ") is not presented.";
Module m = (*f)(static_cast<void*>(stream));
this->imports_.push_back(m);
}
}
} }
private: private:
......
...@@ -36,6 +36,19 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { ...@@ -36,6 +36,19 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
} }
} }
// Binary serialization: writes name, arg_types and thread_axis_tags,
// in that order. FunctionInfo::Load(dmlc::Stream*) reads the fields back
// in the same order.
void FunctionInfo::Save(dmlc::Stream* writer) const {
writer->Write(name);
writer->Write(arg_types);
writer->Write(thread_axis_tags);
}
// Binary deserialization: reads name, arg_types and thread_axis_tags, in
// that order. Stops at the first read that fails and returns false;
// returns true only when all three reads succeed.
bool FunctionInfo::Load(dmlc::Stream* reader) {
  return reader->Read(&name) &&
         reader->Read(&arg_types) &&
         reader->Read(&thread_axis_tags);
}
std::string GetFileFormat(const std::string& file_name, std::string GetFileFormat(const std::string& file_name,
const std::string& format) { const std::string& format) {
std::string fmt = format; std::string fmt = format;
......
...@@ -29,7 +29,13 @@ struct FunctionInfo { ...@@ -29,7 +29,13 @@ struct FunctionInfo {
void Save(dmlc::JSONWriter *writer) const; void Save(dmlc::JSONWriter *writer) const;
void Load(dmlc::JSONReader *reader); void Load(dmlc::JSONReader *reader);
void Save(dmlc::Stream *writer) const;
bool Load(dmlc::Stream *reader);
}; };
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
namespace dmlc {
// Declares the has_saveload trait as true for FunctionInfo, which provides
// Save(dmlc::Stream*) and Load(dmlc::Stream*).
// NOTE(review): presumably this is what lets dmlc::Stream Read/Write handle
// containers of FunctionInfo via those members - confirm against dmlc-core.
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::FunctionInfo, true);
}  // namespace dmlc
#endif // TVM_RUNTIME_META_DATA_H_ #endif // TVM_RUNTIME_META_DATA_H_
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#if TVM_METAL_RUNTIME #if TVM_METAL_RUNTIME
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
#include <array> #include <array>
...@@ -54,6 +55,11 @@ class MetalModuleNode final :public runtime::ModuleNode { ...@@ -54,6 +55,11 @@ class MetalModuleNode final :public runtime::ModuleNode {
SaveBinaryToFile(file_name, data_); SaveBinaryToFile(file_name, data_);
} }
// Write this module to `stream` as three consecutive fields:
// fmt_, fmap_, data_. MetalModuleLoadBinary reads them back in
// exactly this order, so the two must be kept in sync.
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(data_);
}
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
if (format == fmt_) return data_; if (format == fmt_) return data_;
if (source_.length() != 0) { if (source_.length() != 0) {
...@@ -261,8 +267,8 @@ Module MetalModuleCreate( ...@@ -261,8 +267,8 @@ Module MetalModuleCreate(
} }
// Load module from module. // Load module from module.
Module MetalModuleLoad(const std::string& file_name, Module MetalModuleLoadFile(const std::string& file_name,
const std::string& format) { const std::string& format) {
std::string data; std::string data;
std::unordered_map<std::string, FunctionInfo> fmap; std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format); std::string fmt = GetFileFormat(file_name, format);
...@@ -272,9 +278,25 @@ Module MetalModuleLoad(const std::string& file_name, ...@@ -272,9 +278,25 @@ Module MetalModuleLoad(const std::string& file_name,
return MetalModuleCreate(data, fmt, fmap, ""); return MetalModuleCreate(data, fmt, fmap, "");
} }
/*!
 * \brief Rebuild a Metal module from a binary stream.
 *
 * Reads, in order, the format string, the function-info map and the module
 * data, i.e. the sequence written by MetalModuleNode::SaveToBinary.
 *
 * \param strm Type-erased pointer; it is cast to dmlc::Stream*.
 * \return The result of MetalModuleCreate, called with "" as its last
 *         argument.
 */
Module MetalModuleLoadBinary(void* strm) {
  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
  std::string data;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt;
  // Fail loudly on a truncated or corrupt blob instead of creating a module
  // from partially read fields.
  CHECK(stream->Read(&fmt)) << "MetalModule: cannot read format from binary";
  CHECK(stream->Read(&fmap)) << "MetalModule: cannot read function map from binary";
  CHECK(stream->Read(&data)) << "MetalModule: cannot read data from binary";
  return MetalModuleCreate(data, fmt, fmap, "");
}
TVM_REGISTER_GLOBAL("module.loadfile_metal") TVM_REGISTER_GLOBAL("module.loadfile_metal")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = MetalModuleLoad(args[0], args[1]); *rv = MetalModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadbinary_metal")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = MetalModuleLoadBinary(args[0]);
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#if TVM_OPENCL_RUNTIME #if TVM_OPENCL_RUNTIME
#include <dmlc/memory_io.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <vector> #include <vector>
#include <string> #include <string>
...@@ -78,6 +79,12 @@ class OpenCLModuleNode : public ModuleNode { ...@@ -78,6 +79,12 @@ class OpenCLModuleNode : public ModuleNode {
SaveBinaryToFile(file_name, data_); SaveBinaryToFile(file_name, data_);
} }
// Write this module to `stream` as three consecutive fields:
// fmt_, fmap_, data_. OpenCLModuleLoadBinary reads them back in
// exactly this order, so the two must be kept in sync.
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(data_);
}
std::string GetSource(const std::string& format) final { std::string GetSource(const std::string& format) final {
if (format == fmt_) return data_; if (format == fmt_) return data_;
if (fmt_ == "cl") { if (fmt_ == "cl") {
...@@ -272,8 +279,8 @@ Module OpenCLModuleCreate( ...@@ -272,8 +279,8 @@ Module OpenCLModuleCreate(
} }
// Load module from module. // Load module from module.
Module OpenCLModuleLoad(const std::string& file_name, Module OpenCLModuleLoadFile(const std::string& file_name,
const std::string& format) { const std::string& format) {
std::string data; std::string data;
std::unordered_map<std::string, FunctionInfo> fmap; std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format); std::string fmt = GetFileFormat(file_name, format);
...@@ -283,14 +290,30 @@ Module OpenCLModuleLoad(const std::string& file_name, ...@@ -283,14 +290,30 @@ Module OpenCLModuleLoad(const std::string& file_name,
return OpenCLModuleCreate(data, fmt, fmap); return OpenCLModuleCreate(data, fmt, fmap);
} }
/*!
 * \brief Rebuild an OpenCL module from a binary stream.
 *
 * Reads, in order, the format string, the function-info map and the module
 * data, i.e. the sequence written by OpenCLModuleNode::SaveToBinary.
 *
 * \param strm Type-erased pointer; it is cast to dmlc::Stream*.
 * \return The result of OpenCLModuleCreate for the fields read.
 */
Module OpenCLModuleLoadBinary(void* strm) {
  dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
  std::string data;
  std::unordered_map<std::string, FunctionInfo> fmap;
  std::string fmt;
  // Fail loudly on a truncated or corrupt blob instead of creating a module
  // from partially read fields.
  CHECK(stream->Read(&fmt)) << "OpenCLModule: cannot read format from binary";
  CHECK(stream->Read(&fmap)) << "OpenCLModule: cannot read function map from binary";
  CHECK(stream->Read(&data)) << "OpenCLModule: cannot read data from binary";
  return OpenCLModuleCreate(data, fmt, fmap);
}
TVM_REGISTER_GLOBAL("module.loadfile_cl") TVM_REGISTER_GLOBAL("module.loadfile_cl")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoad(args[0], args[1]); *rv = OpenCLModuleLoadFile(args[0], args[1]);
}); });
TVM_REGISTER_GLOBAL("module.loadfile_clbin") TVM_REGISTER_GLOBAL("module.loadfile_clbin")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoad(args[0], args[1]); *rv = OpenCLModuleLoadFile(args[0], args[1]);
});
TVM_REGISTER_GLOBAL("module.loadbinary_opencl")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = OpenCLModuleLoadBinary(args[0]);
}); });
} // namespace runtime } // namespace runtime
} // namespace tvm } // namespace tvm
......
...@@ -95,7 +95,6 @@ def test_add(): ...@@ -95,7 +95,6 @@ def test_add():
device, device,
name="myadd") name="myadd")
ctx = tvm.context(device, 0) ctx = tvm.context(device, 0)
print(fadd.imported_modules[0].get_source())
# launch the kernel. # launch the kernel.
n = 1024 n = 1024
a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx) a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
......
...@@ -68,5 +68,39 @@ def test_dso_module_load(): ...@@ -68,5 +68,39 @@ def test_dso_module_load():
"python %s %s %s" % (path_runtime_py, path_dso, dtype), "python %s %s %s" % (path_runtime_py, path_dso, dtype),
shell=True) shell=True)
def test_device_module_dump():
    """Round-trip a device build through export_library and tvm.module.load.

    For each device that exists, builds an element-wise ``A + 1.0`` kernel,
    exports it as a single shared library, loads it back and checks the
    result computed by the reloaded module.
    """
    # Describe the computation.
    n = tvm.convert(1024)
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute(A.shape, lambda *i: A(*i) + 1.0, name='B')
    sched = tvm.create_schedule(B.op)
    # Split the single axis and bind the two parts to block/thread indices.
    num_thread = 8
    block_iv, thread_iv = sched[B].split(B.op.axis[0], factor=num_thread)
    sched[B].bind(block_iv, tvm.thread_axis("blockIdx.x"))
    sched[B].bind(thread_iv, tvm.thread_axis("threadIdx.x"))

    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        workdir = util.tempdir()
        built = tvm.build(sched, [A, B], device, name="myadd")
        lib_path = workdir.relpath("dev_lib.so")
        built.export_library(lib_path)
        reloaded = tvm.module.load(lib_path)
        a = tvm.nd.array(np.random.uniform(size=1024).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(1024, dtype=A.dtype), ctx)
        reloaded(a, b)
        np.testing.assert_equal(b.asnumpy(), a.asnumpy() + 1)

    for device in ("cuda", "opencl", "metal"):
        check_device(device)
if __name__ == "__main__": if __name__ == "__main__":
test_device_module_dump()
test_dso_module_load() test_dso_module_load()
...@@ -190,8 +190,7 @@ print(temp.listdir()) ...@@ -190,8 +190,7 @@ print(temp.listdir())
# The CPU(host) module is directly saved as a shared library(so). # The CPU(host) module is directly saved as a shared library(so).
# There can be multiple customed format on the device code. # There can be multiple customed format on the device code.
# In our example, device code is stored in ptx, as well as a meta # In our example, device code is stored in ptx, as well as a meta
# data json file. In the future we can consider pack every binary # data json file. They can be loaded and linked seperatedly via import.
# into one shared library.
# #
###################################################################### ######################################################################
...@@ -208,6 +207,20 @@ fadd1(a, b, c) ...@@ -208,6 +207,20 @@ fadd1(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
###################################################################### ######################################################################
# Pack Everything into One Library
# --------------------------------
# In the above example, we store the device and host code separately.
# TVM also supports exporting everything as one shared library.
# Under the hood, we pack the device modules into binary blobs and link
# them together with the host code.
# Currently we support packing of Metal, OpenCL and CUDA modules.
#
fadd_cuda.export_library(temp.relpath("myadd_pack.so"))
fadd2 = tvm.module.load(temp.relpath("myadd_pack.so"))
fadd2(a, b, c)
np.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy())
######################################################################
# .. note:: Runtime API and Thread-Safety # .. note:: Runtime API and Thread-Safety
# #
# The compiled modules of TVM do not depend on the TVM compiler. # The compiled modules of TVM do not depend on the TVM compiler.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment