Commit f4d1dddb by ziheng Committed by GitHub

[EXECUTOR] Save/Load Params (#242)

* [EXECUTOR] Save/Load Params

* [EXECUTOR] Improve Save/Load, fix Makefile

* [EXECUTOR] Make save independent with executor
parent 9c36d9f0
# Minimum Makefile for the extension package # Minimum Makefile for the extension package
TVM_ROOT=$(shell cd ../..; pwd) TVM_ROOT=$(shell cd ../..; pwd)
NNVM_PATH=nnvm NNVM_PATH=nnvm
DMLC_CORE=${TVM_ROOT}/dmlc-core
PKG_CFLAGS = -std=c++11 -O2 -fPIC\ PKG_CFLAGS = -std=c++11 -O2 -fPIC\
-I${TVM_ROOT}/include\ -I${TVM_ROOT}/include\
-I${TVM_ROOT}/dmlc-core/include\ -I${DMLC_CORE}/include\
-I${TVM_ROOT}/dlpack/include\ -I${TVM_ROOT}/dlpack/include\
-I${TVM_ROOT}/HalideIR/src -I${TVM_ROOT}/HalideIR/src
...@@ -23,10 +24,12 @@ endif ...@@ -23,10 +24,12 @@ endif
NNVM_CONTRIB_SRC = $(wildcard src/*.cc) NNVM_CONTRIB_SRC = $(wildcard src/*.cc)
NNVM_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(NNVM_CONTRIB_SRC)) NNVM_CONTRIB_OBJ = $(patsubst src/%.cc, build/%.o, $(NNVM_CONTRIB_SRC))
include $(DMLC_CORE)/make/dmlc.mk
ALL_DEP = $(NNVM_CONTRIB_OBJ) ALL_DEP = $(NNVM_CONTRIB_OBJ)
PKG_CFLAGS += -I${NNVM_PATH}/include PKG_CFLAGS += -I${NNVM_PATH}/include
ALL_DEP += ${NNVM_PATH}/lib/libnnvm.a ALL_DEP += ${DMLC_CORE}/libdmlc.a ${NNVM_PATH}/lib/libnnvm.a
.PHONY: clean all .PHONY: clean all
...@@ -38,6 +41,8 @@ nnvm: ...@@ -38,6 +41,8 @@ nnvm:
nnvm/lib/libnnvm.a: | nnvm nnvm/lib/libnnvm.a: | nnvm
+ cd nnvm; make ; cd - + cd nnvm; make ; cd -
$(DMLC_CORE)/libdmlc.a:
+ cd $(DMLC_CORE); make libdmlc.a; cd $(TVM_ROOT)
build/%.o: src/%.cc | nnvm build/%.o: src/%.cc | nnvm
@mkdir -p $(@D) @mkdir -p $(@D)
......
...@@ -4,6 +4,6 @@ import tvm ...@@ -4,6 +4,6 @@ import tvm
from . import _base from . import _base
from nnvm.symbol import * from nnvm.symbol import *
from . import op_tvm_def from . import op_tvm_def
from .build import build, bind, save_params
...@@ -42,5 +42,16 @@ def _lower(sch, inputs, func_name): ...@@ -42,5 +42,16 @@ def _lower(sch, inputs, func_name):
@tvm.register_func("tvm_graph.build_target")
def _build(funcs, target):
    """Registered hook the graph compiler calls to build lowered funcs.

    Passes ``target`` by keyword so it is not mistaken for a positional
    argument of ``tvm.build``.
    """
    return tvm.build(funcs, target=target)
# Handle to the C++ packed function that serializes a parameter dict.
_save_param_dict = tvm.get_global_func("tvm_graph._save_param_dict")


def save_params(fname, params):
    """Save a dict of name -> NDArray parameters to ``fname``.

    The underlying packed function expects a flat argument list:
    ``(fname, count, name0, array0, name1, array1, ...)``.
    """
    flat_args = [fname, len(params)]
    for name, arr in params.items():
        flat_args.extend((name, arr))
    _save_param_dict(*flat_args)
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
* Copyright (c) 2017 by Contributors * Copyright (c) 2017 by Contributors
* \file NNVM Graph executor. * \file NNVM Graph executor.
*/ */
#include <dmlc/io.h>
#include <tvm/runtime/registry.h> #include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h> #include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h> #include <tvm/runtime/module.h>
...@@ -52,6 +53,8 @@ class GraphExecutor : public runtime::ModuleNode { ...@@ -52,6 +53,8 @@ class GraphExecutor : public runtime::ModuleNode {
void SetInput(int index, DLTensor* data_in); void SetInput(int index, DLTensor* data_in);
// Copy index-th output to data_out // Copy index-th output to data_out
void GetOutput(int index, DLTensor* data_out); void GetOutput(int index, DLTensor* data_out);
// Load parameters from file
void LoadParams(std::string fname);
// Execute the graph. // Execute the graph.
void Run(); void Run();
...@@ -93,6 +96,10 @@ PackedFunc GraphExecutor::GetFunction( ...@@ -93,6 +96,10 @@ PackedFunc GraphExecutor::GetFunction(
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->Run(); this->Run();
}); });
} else if (name == "load_params") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->LoadParams(args[0]);
});
} else { } else {
return PackedFunc(); return PackedFunc();
} }
...@@ -133,6 +140,138 @@ void GraphExecutor::GetOutput(int index, DLTensor* data_out) { ...@@ -133,6 +140,138 @@ void GraphExecutor::GetOutput(int index, DLTensor* data_out) {
TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr)); TVM_CCALL(TVMArrayCopyFromTo(&data_entry_[eid], data_out, nullptr));
} }
// Magic number identifying a serialized DLTensor record.
constexpr uint64_t kTVMNDArrayMagic = 0xDD5E40F096B4A13F;

// Serialize one tensor to `strm` as:
//   uint64 magic, uint64 reserved, ctx, ndim, dtype,
//   int64 shape[ndim], int64 data_byte_size, raw data bytes.
// Always returns true; dmlc::Stream::Write has no error return here.
bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
  uint64_t header = kTVMNDArrayMagic, reserved = 0;
  strm->Write(&header, sizeof(header));
  strm->Write(&reserved, sizeof(reserved));
  strm->Write(&tensor->ctx, sizeof(tensor->ctx));
  strm->Write(&tensor->ndim, sizeof(tensor->ndim));
  strm->Write(&tensor->dtype, sizeof(tensor->dtype));
  int ndim = tensor->ndim;
  strm->Write(tensor->shape, sizeof(int64_t) * ndim);
  // Round up to whole bytes and account for vector lanes; for the common
  // case (8/16/32/64-bit scalars, lanes == 1) this equals the previous
  // `bits / 8`, but it no longer yields 0 for sub-byte dtypes.
  int64_t type_size = (tensor->dtype.bits * tensor->dtype.lanes + 7) / 8;
  int64_t num_elems = 1;
  for (int i = 0; i < ndim; ++i) {
    num_elems *= tensor->shape[i];
  }
  int64_t data_byte_size = type_size * num_elems;
  strm->Write(&data_byte_size, sizeof(data_byte_size));
  strm->Write(tensor->data, data_byte_size);
  return true;
}
// Deserialize one DLTensor record (written by SaveDLTensor) from `strm`
// into a pre-allocated `tensor`.
//
// The caller must supply a tensor whose `shape` and `data` buffers are
// already allocated. The ndim read from the file is validated against the
// destination tensor's ndim BEFORE the shape read, so a corrupt or
// mismatched file cannot overrun `tensor->shape`.
// NOTE(review): this overwrites tensor->ctx/ndim/dtype with file values
// while tensor->data still points at the originally allocated buffer —
// presumably save and load always use matching graphs; verify at callers.
bool LoadDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
  uint64_t header, reserved;
  CHECK(strm->Read(&header, sizeof(header)))
      << "Invalid DLTensor file format";
  CHECK(strm->Read(&reserved, sizeof(reserved)))
      << "Invalid DLTensor file format";
  CHECK(header == kTVMNDArrayMagic)
      << "Invalid DLTensor file format";
  // Remember the destination's rank before it is overwritten below.
  int alloc_ndim = tensor->ndim;
  CHECK(strm->Read(&tensor->ctx, sizeof(tensor->ctx)))
      << "Invalid DLTensor file format";
  CHECK(strm->Read(&tensor->ndim, sizeof(tensor->ndim)))
      << "Invalid DLTensor file format";
  CHECK(strm->Read(&tensor->dtype, sizeof(tensor->dtype)))
      << "Invalid DLTensor file format";
  int ndim = tensor->ndim;
  CHECK(ndim == alloc_ndim)
      << "DLTensor ndim mismatch: file has " << ndim
      << " but destination tensor was allocated with " << alloc_ndim;
  CHECK(strm->Read(tensor->shape, sizeof(int64_t) * ndim))
      << "Invalid DLTensor file format";
  int64_t num_elems = 1;
  // Round up to whole bytes and include lanes; equals bits/8 for the
  // standard scalar dtypes, matching what SaveDLTensor writes.
  int64_t type_size = (tensor->dtype.bits * tensor->dtype.lanes + 7) / 8;
  for (int i = 0; i < ndim; ++i) {
    num_elems *= tensor->shape[i];
  }
  int64_t data_byte_size;
  CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size)))
      << "Invalid DLTensor file format";
  CHECK(data_byte_size == type_size * num_elems)
      << "Invalid DLTensor file format";
  CHECK(strm->Read(tensor->data, data_byte_size))
      << "Invalid DLTensor file format";
  return true;
}
// Magic number identifying a serialized list of named tensors.
constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7;

// Packed-function entry point for saving a parameter dict.
// Call convention: (filename, count, name0, tensor0, name1, tensor1, ...).
// Writes the list header, the vector of names, the tensor count, then one
// DLTensor record per parameter (see SaveDLTensor).
TVM_REGISTER_GLOBAL("tvm_graph._save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
    std::string fname = args[0];
    int num_params = args[1];
    std::vector<std::string> names;
    std::vector<DLTensor*> arrays;
    names.reserve(num_params);
    arrays.reserve(num_params);
    // Arguments arrive as alternating (name, tensor) pairs after the
    // first two slots.
    for (int i = 0; i < num_params; ++i) {
      names.emplace_back(args[2 + 2 * i].operator std::string());
      arrays.emplace_back(args[3 + 2 * i].operator DLTensor*());
    }
    std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(fname.c_str(), "w"));
    uint64_t header = kTVMNDArrayListMagic;
    uint64_t reserved = 0;
    fo->Write(&header, sizeof(header));
    fo->Write(&reserved, sizeof(reserved));
    fo->Write(names);
    uint64_t sz = static_cast<uint64_t>(arrays.size());
    fo->Write(&sz, sizeof(sz));
    for (uint64_t i = 0; i < sz; ++i) {
      SaveDLTensor(fo.get(), arrays[i]);
    }
  });
// Load parameters previously written by tvm_graph._save_param_dict from
// `fname` and copy them into this executor's data entries.
//
// Expected file layout:
//   uint64 kTVMNDArrayListMagic, uint64 reserved,
//   vector<string> parameter names, uint64 tensor count,
//   then one serialized DLTensor record per name (see LoadDLTensor).
void GraphExecutor::LoadParams(std::string fname) {
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname.c_str(), "r"));
uint64_t header, reserved;
CHECK(fi->Read(&header))
<< "Invalid parameters file format";
CHECK(header == kTVMNDArrayListMagic)
<< "Invalid parameters file format";
CHECK(fi->Read(&reserved))
<< "Invalid parameters file format";
// Names of the saved parameters, in the order their tensors follow.
std::vector<std::string> names;
CHECK(fi->Read(&names))
<< "Invalid parameters file format";
// Rebuild the graph's input-name list so each saved name can be mapped
// back to an input position.
nnvm::Symbol s;
s.outputs = graph_.outputs;
std::vector<std::string> input_names =
s.ListInputNames(nnvm::Symbol::ListInputOption::kAll);
std::unordered_map<std::string, size_t> name_index;
for (size_t i = 0; i < input_names.size(); ++i) {
name_index.emplace(input_names[i], i);
}
{
uint64_t sz;
fi->Read(&sz, sizeof(sz));
size_t size = static_cast<size_t>(sz);
// The name list and tensor count must agree.
CHECK(size == names.size())
<< "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
// NOTE(review): .at() throws std::out_of_range (not a CHECK failure)
// if the file names a parameter that is not a graph input.
size_t idx = name_index.at(names[i]);
// NOTE(review): `idx` is a position in ListInputNames(kAll) but is
// used directly as an index into data_entry_; this assumes input i's
// storage lives at data_entry_[i] — verify against SetupStorage's
// entry layout (not visible in this chunk).
CHECK(LoadDLTensor(fi.get(), &data_entry_[idx]))
<< "Invalid parameters file format";
}
}
}
void GraphExecutor::SetupStorage() { void GraphExecutor::SetupStorage() {
const auto& idx = graph_.indexed_graph(); const auto& idx = graph_.indexed_graph();
// Grab saved optimization plan from graph. // Grab saved optimization plan from graph.
......
import tvm_graph as tg
import numpy as np
import tvm
def test_save_load():
    """Round-trip check for save_params / load_params.

    Builds a small graph, runs one executor with inputs set directly,
    saves those inputs as parameters, loads them into a second executor,
    and asserts both executors produce identical outputs.
    """
    import os
    import tempfile

    shape = (10, 128)
    dtype = tvm.float32
    na = tvm.nd.array(np.ones(shape).astype(dtype))
    nb = tvm.nd.array(np.ones(shape).astype(dtype))
    x = tg.Variable('x')
    y = tg.Variable('y')
    z = tg.exp(y + x)
    g = tg.build(z, "llvm", shape={'x': shape, 'y': shape})

    # First executor: feed inputs via set_input.
    m0 = tg.bind(g, tvm.cpu(0))
    set_input0, run0, get_output0 = m0['set_input'], m0['run'], m0['get_output']
    set_input0(0, na)
    set_input0(1, nb)
    run0()
    out0 = tvm.nd.array(np.zeros(shape).astype(dtype))
    get_output0(0, out0)

    # Save to a temp file so the test leaves no artifacts in the CWD.
    fd, fname = tempfile.mkstemp(suffix='.params')
    os.close(fd)
    try:
        tg.save_params(fname, {'x': na, 'y': nb})
        # Second executor: inputs come solely from the saved parameters.
        m1 = tg.bind(g, tvm.cpu(0))
        load_params1 = m1['load_params']
        load_params1(fname)
        run1, get_output1 = m1['run'], m1['get_output']
        run1()
        out1 = tvm.nd.array(np.zeros(shape).astype(dtype))
        get_output1(0, out1)
    finally:
        os.remove(fname)
    np.testing.assert_allclose(out0.asnumpy(), out1.asnumpy())


if __name__ == "__main__":
    test_save_load()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment