Commit 3479b9ab by Lianmin Zheng Committed by Tianqi Chen

[RUNTIME] support limited save without cross compile (#659)

parent db743028
...@@ -35,7 +35,7 @@ runtime::Module BuildMetal(Array<LoweredFunc> funcs) { ...@@ -35,7 +35,7 @@ runtime::Module BuildMetal(Array<LoweredFunc> funcs) {
return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source); return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source);
#else #else
LOG(WARNING) << "Metal runtime not enabled, return a source module..."; LOG(WARNING) << "Metal runtime not enabled, return a source module...";
return SourceModuleCreate(code, "metal"); return DeviceSourceModuleCreate(code, "metal", ExtractFuncInfo(funcs), "metal");
#endif // TVM_METAL_RUNTIME #endif // TVM_METAL_RUNTIME
} }
......
...@@ -27,7 +27,7 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) { ...@@ -27,7 +27,7 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs)); return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs));
#else #else
LOG(WARNING) << "OpenCL runtime not enabled, return a source module..."; LOG(WARNING) << "OpenCL runtime not enabled, return a source module...";
return SourceModuleCreate(code, "cl"); return DeviceSourceModuleCreate(code, "cl", ExtractFuncInfo(funcs), "opencl");
#endif // TVM_OPENCL_RUNTIME #endif // TVM_OPENCL_RUNTIME
} }
......
...@@ -38,7 +38,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) { ...@@ -38,7 +38,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
stream->Write(sz); stream->Write(sz);
for (runtime::Module im : mod->imports()) { for (runtime::Module im : mod->imports()) {
CHECK_EQ(im->imports().size(), 0U) CHECK_EQ(im->imports().size(), 0U)
<< "Only support simply one-level hierachy"; << "Only support simply one-level hierarchy";
std::string tkey = im->type_key(); std::string tkey = im->type_key();
std::string bin; std::string bin;
stream->Write(tkey); stream->Write(tkey);
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <unordered_map>
#include "../runtime/meta_data.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -108,6 +109,19 @@ class CodeGenSourceBase { ...@@ -108,6 +109,19 @@ class CodeGenSourceBase {
* \param fmt The code. format. * \param fmt The code. format.
*/ */
runtime::Module SourceModuleCreate(std::string code, std::string fmt); runtime::Module SourceModuleCreate(std::string code, std::string fmt);
/*!
* \brief Create a source module for viewing and limited saving
* \param code The code to be viewed.
* \param fmt The code. format.
* \param fmap The map function information map of each function.
* \param type_key The type_key of the runtime module of this source code
*/
runtime::Module DeviceSourceModuleCreate(
std::string code,
std::string fmt,
std::unordered_map<std::string, runtime::FunctionInfo> fmap,
std::string type_key);
} // namespace codegen } // namespace codegen
} // namespace tvm } // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_ #endif // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
*/ */
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include "./codegen_source_base.h" #include "./codegen_source_base.h"
#include "../runtime/file_util.h"
#include "../runtime/meta_data.h"
namespace tvm { namespace tvm {
namespace codegen { namespace codegen {
...@@ -12,8 +14,14 @@ namespace codegen { ...@@ -12,8 +14,14 @@ namespace codegen {
using runtime::TVMArgs; using runtime::TVMArgs;
using runtime::TVMRetValue; using runtime::TVMRetValue;
using runtime::PackedFunc; using runtime::PackedFunc;
using runtime::GetFileFormat;
using runtime::GetMetaFilePath;
using runtime::FunctionInfo;
using runtime::SaveBinaryToFile;
// Simulator function // Simulator function
class SourceModuleNode final : public runtime::ModuleNode { class SourceModuleNode : public runtime::ModuleNode {
public: public:
SourceModuleNode(std::string code, SourceModuleNode(std::string code,
std::string fmt) std::string fmt)
...@@ -21,6 +29,7 @@ class SourceModuleNode final : public runtime::ModuleNode { ...@@ -21,6 +29,7 @@ class SourceModuleNode final : public runtime::ModuleNode {
const char* type_key() const { const char* type_key() const {
return "source"; return "source";
} }
PackedFunc GetFunction( PackedFunc GetFunction(
const std::string& name, const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final { const std::shared_ptr<ModuleNode>& sptr_to_self) final {
...@@ -33,7 +42,7 @@ class SourceModuleNode final : public runtime::ModuleNode { ...@@ -33,7 +42,7 @@ class SourceModuleNode final : public runtime::ModuleNode {
return code_; return code_;
} }
private: protected:
std::string code_; std::string code_;
std::string fmt_; std::string fmt_;
}; };
...@@ -44,6 +53,50 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) { ...@@ -44,6 +53,50 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
return runtime::Module(n); return runtime::Module(n);
} }
// supports limited save without cross compile
class DeviceSourceModuleNode final : public SourceModuleNode {
public:
DeviceSourceModuleNode(std::string code,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string type_key)
: SourceModuleNode(code, fmt), fmap_(fmap), type_key_(type_key) {}
const char* type_key() const {
return type_key_.c_str();
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, code_);
}
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(code_);
}
private:
std::unordered_map<std::string, FunctionInfo> fmap_;
std::string type_key_;
};
runtime::Module DeviceSourceModuleCreate(
std::string code,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string type_key) {
std::shared_ptr<DeviceSourceModuleNode> n =
std::make_shared<DeviceSourceModuleNode>(code, fmt, fmap, type_key);
return runtime::Module(n);
}
TVM_REGISTER_GLOBAL("module.source_module_create") TVM_REGISTER_GLOBAL("module.source_module_create")
.set_body([](TVMArgs args, TVMRetValue* rv) { .set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = SourceModuleCreate(args[0], args[1]); *rv = SourceModuleCreate(args[0], args[1]);
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
namespace tvm { namespace tvm {
namespace runtime { namespace runtime {
/*! /*!
* \brief create a cuda module from data. * \brief create a opencl module from data.
* *
* \param data The module data. * \param data The module data.
* \param fmt The format of the data, can be "clbin", "cl" * \param fmt The format of the data, can be "clbin", "cl"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment