Commit 3479b9ab by Lianmin Zheng Committed by Tianqi Chen

[RUNTIME] support limited save without cross compile (#659)

parent db743028
......@@ -35,7 +35,7 @@ runtime::Module BuildMetal(Array<LoweredFunc> funcs) {
return MetalModuleCreate(code, fmt, ExtractFuncInfo(funcs), source);
#else
LOG(WARNING) << "Metal runtime not enabled, return a source module...";
return SourceModuleCreate(code, "metal");
return DeviceSourceModuleCreate(code, "metal", ExtractFuncInfo(funcs), "metal");
#endif // TVM_METAL_RUNTIME
}
......
......@@ -27,7 +27,7 @@ runtime::Module BuildOpenCL(Array<LoweredFunc> funcs) {
return OpenCLModuleCreate(code, "cl", ExtractFuncInfo(funcs));
#else
LOG(WARNING) << "OpenCL runtime not enabled, return a source module...";
return SourceModuleCreate(code, "cl");
return DeviceSourceModuleCreate(code, "cl", ExtractFuncInfo(funcs), "opencl");
#endif // TVM_OPENCL_RUNTIME
}
......
......@@ -38,7 +38,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
stream->Write(sz);
for (runtime::Module im : mod->imports()) {
CHECK_EQ(im->imports().size(), 0U)
<< "Only support simply one-level hierachy";
<< "Only support simply one-level hierarchy";
std::string tkey = im->type_key();
std::string bin;
stream->Write(tkey);
......
......@@ -11,6 +11,7 @@
#include <string>
#include <vector>
#include <unordered_map>
#include "../runtime/meta_data.h"
namespace tvm {
namespace codegen {
......@@ -108,6 +109,19 @@ class CodeGenSourceBase {
* \param fmt The code. format.
*/
runtime::Module SourceModuleCreate(std::string code, std::string fmt);
/*!
* \brief Create a source module for viewing and limited saving
* \param code The code to be viewed.
* \param fmt The code. format.
* \param fmap The map function information map of each function.
* \param type_key The type_key of the runtime module of this source code
*/
runtime::Module DeviceSourceModuleCreate(
std::string code,
std::string fmt,
std::unordered_map<std::string, runtime::FunctionInfo> fmap,
std::string type_key);
} // namespace codegen
} // namespace tvm
#endif // TVM_CODEGEN_CODEGEN_SOURCE_BASE_H_
......@@ -5,6 +5,8 @@
*/
#include <tvm/runtime/packed_func.h>
#include "./codegen_source_base.h"
#include "../runtime/file_util.h"
#include "../runtime/meta_data.h"
namespace tvm {
namespace codegen {
......@@ -12,8 +14,14 @@ namespace codegen {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;
using runtime::GetFileFormat;
using runtime::GetMetaFilePath;
using runtime::FunctionInfo;
using runtime::SaveBinaryToFile;
// Simulator function
class SourceModuleNode final : public runtime::ModuleNode {
class SourceModuleNode : public runtime::ModuleNode {
public:
SourceModuleNode(std::string code,
std::string fmt)
......@@ -21,6 +29,7 @@ class SourceModuleNode final : public runtime::ModuleNode {
const char* type_key() const {
return "source";
}
PackedFunc GetFunction(
const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final {
......@@ -33,7 +42,7 @@ class SourceModuleNode final : public runtime::ModuleNode {
return code_;
}
private:
protected:
std::string code_;
std::string fmt_;
};
......@@ -44,6 +53,50 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
return runtime::Module(n);
}
// supports limited save without cross compile
class DeviceSourceModuleNode final : public SourceModuleNode {
public:
DeviceSourceModuleNode(std::string code,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string type_key)
: SourceModuleNode(code, fmt), fmap_(fmap), type_key_(type_key) {}
const char* type_key() const {
return type_key_.c_str();
}
void SaveToFile(const std::string& file_name,
const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
CHECK_EQ(fmt, fmt_)
<< "Can only save to format=" << fmt_;
std::string meta_file = GetMetaFilePath(file_name);
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, code_);
}
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
stream->Write(code_);
}
private:
std::unordered_map<std::string, FunctionInfo> fmap_;
std::string type_key_;
};
runtime::Module DeviceSourceModuleCreate(
std::string code,
std::string fmt,
std::unordered_map<std::string, FunctionInfo> fmap,
std::string type_key) {
std::shared_ptr<DeviceSourceModuleNode> n =
std::make_shared<DeviceSourceModuleNode>(code, fmt, fmap, type_key);
return runtime::Module(n);
}
TVM_REGISTER_GLOBAL("module.source_module_create")
.set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = SourceModuleCreate(args[0], args[1]);
......
......@@ -16,7 +16,7 @@
namespace tvm {
namespace runtime {
/*!
* \brief create a cuda module from data.
* \brief create a opencl module from data.
*
* \param data The module data.
* \param fmt The format of the data, can be "clbin", "cl"
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment