Commit d4a51751 by Thomas Viehmann Committed by masahi

ROCm: Add SaveToFile and LoadFile (#3665)

...and add rocm module_save to the tests.
parent 0cecd037
......@@ -71,6 +71,16 @@ class ROCMModuleNode : public runtime::ModuleNode {
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
void SaveToFile(const std::string& file_name,
const std::string& format) final {
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
// note: llvm and asm formats are not laodable, so we don't save them
CHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
SaveMetaDataToFile(meta_file, fmap_);
SaveBinaryToFile(file_name, data_);
}
void SaveToBinary(dmlc::Stream* stream) final {
stream->Write(fmt_);
stream->Write(fmap_);
......@@ -230,6 +240,17 @@ Module ROCMModuleCreate(
return Module(n);
}
Module ROCMModuleLoadFile(const std::string& file_name,
const std::string& format) {
std::string data;
std::unordered_map<std::string, FunctionInfo> fmap;
std::string fmt = GetFileFormat(file_name, format);
std::string meta_file = GetMetaFilePath(file_name);
LoadBinaryFromFile(file_name, &data);
LoadMetaDataFromFile(meta_file, &fmap);
return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string());
}
Module ROCMModuleLoadBinary(void* strm) {
dmlc::Stream* stream = static_cast<dmlc::Stream*>(strm);
std::string data;
......@@ -248,5 +269,12 @@ TVM_REGISTER_GLOBAL("module.loadbinary_hsaco")
TVM_REGISTER_GLOBAL("module.loadbinary_hip")
.set_body_typed(ROCMModuleLoadBinary);
TVM_REGISTER_GLOBAL("module.loadfile_hsaco")
.set_body_typed(ROCMModuleLoadFile);
TVM_REGISTER_GLOBAL("module.loadfile_hip")
.set_body_typed(ROCMModuleLoadFile);
} // namespace runtime
} // namespace tvm
......@@ -76,7 +76,12 @@ def test_add_pipeline():
return
if not tvm.module.enabled(host):
return
fmt = "ptx" if device == "cuda" else device
if device == "cuda":
fmt = "ptx"
elif device == "rocm":
fmt = "hsaco"
else:
fmt = device
mhost = tvm.codegen.build_module(fsplits[0], host)
mdev = tvm.codegen.build_module(fsplits[1:], device)
temp = util.tempdir()
......@@ -99,8 +104,9 @@ def test_add_pipeline():
check_module_save("cuda", host="stackvm")
check_target("nvptx", host="llvm")
check_target("vulkan", host="llvm")
check_target("rocm", host="llvm")
check_module_save("vulkan", host="stackvm")
check_target("rocm", host="llvm")
check_module_save("rocm", host="llvm")
if __name__ == "__main__":
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment