Commit f21b5cab by Tianqi Chen

[DOCS] Add save_param_dict, readme (#42)

parent b7b00611
@@ -3,16 +3,54 @@
[![Build Status](https://travis-ci.org/dmlc/nnvm.svg?branch=master)](https://travis-ci.org/dmlc/nnvm)
[![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE)

NNVM is a reusable computational graph optimization and compilation stack for deep learning systems. It provides modules to:

- Represent deep learning workloads from front-end frameworks via a graph IR.
- Optimize computation graphs to improve performance.
- Compile into executable modules and deploy to different hardware backends with minimum dependency.

NNVM is designed so that new frontends, operators, and graph optimizations can be added in a decentralized fashion without changing the core interface. NNVM is part of the [TVM stack](https://github.com/dmlc/tvm); the NNVM compiler toolchain can target any hardware backend supported by TVM.
The compiled modules can be deployed to servers, mobile and embedded devices, and browsers with minimum dependency, in languages including C++, Python, JavaScript, Java, and Objective-C.

The following code snippet demonstrates the general workflow of the NNVM compiler toolchain.

```python
import tvm
from tvm.contrib import graph_runtime, rpc
import nnvm.frontend
import nnvm.compiler

# Get the model from a frontend framework.
# Change xyz to a supported framework name.
graph, params = nnvm.frontend.from_xyz(...)

# Optimize and compile the graph to get a deployable module.
# target can be "cuda", "opencl", "llvm", "metal" or any other target supported by TVM.
target = "cuda"
graph, lib, params = nnvm.compiler.build(
    graph, target, shape={"data": data_shape}, params=params)

# Deploy and run on gpu(0).
module = graph_runtime.create(graph, lib, tvm.gpu(0))
module.set_input(**params)
output = tvm.nd.empty(out_shape, ctx=tvm.gpu(0))
for data_array in dataset:
    module.set_input("data", data_array)
    module.run()
    module.get_output(0, output)

# Deploy to a remote mobile/rasp/browser device with the minimum tvm rpc runtime.
# Useful for quick experiments on mobile devices.
remote = rpc.connect(remote_host, remote_port)
lib.export_library("mylib.so")
remote.upload("mylib.so")
rlib = remote.load_module("mylib.so")

# Run on the remote device.
rmodule = graph_runtime.create(graph, rlib, remote.gpu(0))
rmodule.set_input(**params)
rmodule.run()
```

## Links

- [TinyFlow](https://github.com/tqchen/tinyflow) on how you can use NNVM to build a TensorFlow-like API.
- [Apache MXNet](http://mxnet.io/) uses NNVM as a backend.
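For reference, here is a minimal sketch (not part of this commit) of how the build outputs above (`graph`, `lib`, `params`) can be persisted and reloaded for offline deployment, using the `save_param_dict` helper added in this commit. The file names and the `cpu(0)` context are illustrative only.

```python
# Sketch only: persist the three build outputs so a runtime-only process can
# pick them up later. File names below are placeholders.
import tvm
from tvm.contrib import graph_runtime
import nnvm.compiler

with open("deploy_graph.json", "w") as f:
    f.write(graph.json())                                   # graph structure as JSON
lib.export_library("deploy_lib.so")                         # compiled operator library
with open("deploy_params.bin", "wb") as f:
    f.write(bytes(nnvm.compiler.save_param_dict(params)))   # serialized weights

# Later, possibly in another process or on the target device:
loaded_graph = open("deploy_graph.json").read()
loaded_lib = tvm.module.load("deploy_lib.so")
module = graph_runtime.create(loaded_graph, loaded_lib, tvm.cpu(0))
module["load_params"](bytearray(open("deploy_params.bin", "rb").read()))
```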
@@ -7,6 +7,10 @@ nnvm.compiler

.. autofunction:: nnvm.compiler.build_config

.. autofunction:: nnvm.compiler.save_param_dict

.. autofunction:: nnvm.compiler.load_param_dict

.. autofunction:: nnvm.compiler.optimize

.. automodule:: nnvm.compiler.graph_util
......
"""NNVM compiler toolchain. """NNVM compiler toolchain.
User only need to use :any:`build` and :any:`build_config` to do the compilation. User only need to use :any:`build` and :any:`build_config` to do the compilation,
and :any:`save_param_dict` to save the parameters into bytes.
The other APIs are for more advanced interaction with the compiler toolchain. The other APIs are for more advanced interaction with the compiler toolchain.
""" """
from __future__ import absolute_import from __future__ import absolute_import
@@ -10,6 +11,7 @@ import tvm
from . import build_module
from . build_module import build, optimize, build_config
from . compile_engine import engine, graph_key
from . param_dict import save_param_dict, load_param_dict
from .. import symbol as _symbol
from .. import graph as _graph
......
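For context, a minimal sketch (not part of this commit) of the newly exported save/load round trip; the parameter name and shape are placeholders.

```python
# Sketch only: serialize a parameter dict to bytes and restore it in memory.
import numpy as np
import nnvm.compiler

params = {"weight": np.ones((2, 3), dtype="float32")}
blob = nnvm.compiler.save_param_dict(params)      # -> bytearray
restored = nnvm.compiler.load_param_dict(blob)    # -> dict of str to tvm.nd.NDArray
np.testing.assert_equal(restored["weight"].asnumpy(), params["weight"])
```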
@@ -9,7 +9,7 @@ from . import graph_attr, graph_util
from .. import graph as _graph

OPT_PASS_LEVEL = {
    "SimplifyInference": 0,
    "PrecomputePrune": 2,
    "OpFusion": 1
}
@@ -26,6 +26,7 @@ class BuildConfig(object):
    current = None
    defaults = {
        "opt_level": 2,
        "add_pass": None,
    }
    def __init__(self, **kwargs):
        self._old_scope = None
@@ -53,6 +54,23 @@ class BuildConfig(object):
        assert self._old_scope
        BuildConfig.current = self._old_scope

    def pass_enabled(self, pass_name):
        """Get whether pass is enabled.

        Parameters
        ----------
        pass_name : str
            The optimization pass name

        Returns
        -------
        enabled : bool
            Whether pass is enabled.
        """
        if self.add_pass and pass_name in self.add_pass:
            return True
        return self.opt_level >= OPT_PASS_LEVEL[pass_name]

BuildConfig.current = BuildConfig()
@@ -64,6 +82,9 @@ def build_config(**kwargs):
    opt_level: int, default=2
        Optimization level. See OPT_PASS_LEVEL for level of each pass.

    add_pass: set of str
        Optimization pass to be added regardless of optimization level.

    Returns
    -------
    config: BuildConfig
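As an illustration of the new `add_pass` option, here is a sketch (not part of this commit); the symbol, shapes, and the "llvm" target are placeholders and assume an LLVM-enabled TVM build.

```python
# Sketch only: opt_level=0 is below the OpFusion threshold (level 1),
# but add_pass forces the fusion pass back on.
import nnvm.symbol as sym
import nnvm.graph
import nnvm.compiler

x = sym.Variable("data")
net = sym.dense(x, units=16)
g = nnvm.graph.create(net)

with nnvm.compiler.build_config(opt_level=0, add_pass={"OpFusion"}):
    graph, lib, params = nnvm.compiler.build(
        g, "llvm", shape={"data": (1, 8)})
```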
@@ -120,7 +141,7 @@ def optimize(graph, shape, dtype="float32"):
    """
    # pylint: disable=unused-argument
    cfg = BuildConfig.current
    if cfg.pass_enabled("SimplifyInference"):
        graph = graph_attr.set_shape_inputs(graph, shape)
        graph = graph.apply(["InferShape", "SimplifyInference"])
    return graph
@@ -182,14 +203,17 @@ def build(graph, target, shape, dtype="float32", params=None):
    # Apply optimization
    graph = optimize(graph, shape, dtype)
    # Precompute prune
    if params and cfg.pass_enabled("PrecomputePrune"):
        graph, params = precompute_prune(graph, params)
        shape, dtype = _update_shape_dtype(shape, dtype, params)
    # Operator fusion and generation
    graph = graph_attr.set_shape_inputs(graph, shape)
    graph = graph_attr.set_dtype_inputs(graph, dtype)
    graph._set_json_attr("target", target, "str")
    if cfg.pass_enabled("OpFusion"):
        graph._set_json_attr("opt_level", 1, "int")
    else:
        graph._set_json_attr("opt_level", 0, "int")
    graph = graph.apply("InferShape").apply("InferType")
    graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
    libmod = graph_attr._move_out_module(graph, "module")
......
"""Helper utility to save parameter dict"""
import tvm
_save_param_dict = tvm.get_global_func("nnvm.compiler._save_param_dict")
_load_param_dict = tvm.get_global_func("nnvm.compiler._load_param_dict")
def save_param_dict(params):
    """Save parameter dictionary to binary bytes.

    The result binary bytes can be loaded by the
    GraphModule with API "load_params".

    Parameters
    ----------
    params : dict of str to NDArray
        The parameter dictionary.

    Returns
    -------
    param_bytes: bytearray
        Serialized parameters.

    Examples
    --------
    .. code-block:: python

        # compile and create the runtime module
        graph, lib, params = nnvm.compiler.build(
            graph, target, shape={"data": data_shape}, params=params)
        module = graph_runtime.create(graph, lib, tvm.gpu(0))
        # save the parameters as a byte array
        param_bytes = nnvm.compiler.save_param_dict(params)
        # We can serialize the param_bytes and load it back later.
        # Pass the byte array to the module to directly set parameters.
        module["load_params"](param_bytes)
    """
    args = []
    for k, v in params.items():
        args.append(k)
        args.append(tvm.nd.array(v))
    return _save_param_dict(*args)

def load_param_dict(param_bytes):
    """Load parameter dictionary from binary bytes.

    Parameters
    ----------
    param_bytes: bytearray
        Serialized parameters.

    Returns
    -------
    params : dict of str to NDArray
        The parameter dictionary.
    """
    if isinstance(param_bytes, (bytes, str)):
        param_bytes = bytearray(param_bytes)
    load_mod = _load_param_dict(param_bytes)
    size = load_mod(0)
    param_dict = {}
    for i in range(size):
        key = load_mod(1, i)
        dltensor_handle = load_mod(2, i)
        param_dict[key] = tvm.nd.NDArray(dltensor_handle, False)
    return param_dict
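A further hypothetical sketch (not part of this commit): using load_param_dict to inspect a previously saved parameter blob before shipping it; the file path is a placeholder.

```python
# Sketch only: print the names, shapes and dtypes of a saved parameter blob.
import nnvm.compiler

with open("deploy_params.bin", "rb") as f:
    blob = bytearray(f.read())

for name, arr in nnvm.compiler.load_param_dict(blob).items():
    print(name, arr.shape, arr.dtype)
```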
@@ -12,7 +12,7 @@ from ._base import _LIB
from ._base import c_array, c_str, nn_uint, py_str, string_types
from ._base import GraphHandle, SymbolHandle
from ._base import check_call
from .symbol import Variable, Symbol, Group as _Group

class GraphIndex(object):
    """Index for quickly accessing graph attributes.

@@ -174,9 +174,19 @@ class Graph(object):
        check_call(_LIB.NNGraphGetSymbol(self.handle, ctypes.byref(shandle)))
        return Symbol(shandle)
    def json(self):
        """Get JSON representation of the graph

        Returns
        -------
        json : str
            JSON representation of the graph
        """
        return self.apply("SaveJSON").json_attr("json")

    def _tvm_graph_json(self):
        """Get TVM graph json"""
        return self.json()

    @property
    def index(self):

@@ -225,6 +235,24 @@
        return Graph(ghandle)

def load_json(json_str):
    """Create a new graph by loading from json

    Parameters
    ----------
    json_str : str
        The json string

    Returns
    -------
    graph : Graph
        The loaded graph
    """
    ret = create(Variable("x"))
    ret._set_json_attr("json", json_str)
    return ret.apply("LoadJSON")

def create(symbol):
    """Create a new graph from symbol.
......
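For context, a small sketch (not part of this commit) of the new graph JSON round trip, mirroring the added unit test; the symbol is a placeholder.

```python
# Sketch only: serialize a graph to JSON and rebuild an equivalent graph.
import nnvm.symbol as sym
import nnvm.graph as graph

x = sym.Variable("x")
g = graph.create(sym.exp(x))

json_str = g.json()             # graph structure as a JSON string
g2 = graph.load_json(json_str)  # rebuild an equivalent graph
assert g2.json() == json_str
```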
@@ -15,46 +15,10 @@
#include <tvm/lowered_func.h>
#include <dmlc/parameter.h>
#include "./compile_engine.h"
#include "./graph_runtime.h"

namespace nnvm {
namespace compiler {
using namespace tvm;

// The single fuse rule.
......
/*!
* Copyright (c) 2017 by Contributors
* \file graph_runtime.cc
* \brief Interface code with TVM graph runtime.
*/
#include <dmlc/memory_io.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include "./graph_runtime.h"
namespace nnvm {
namespace compiler {
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;
using tvm::runtime::kTVMNDArrayMagic;
using tvm::runtime::kTVMNDArrayListMagic;
DMLC_REGISTER_PARAMETER(TVMOpParam);
// parser
inline void TVMOpParamParser(nnvm::NodeAttrs* attrs) {
TVMOpParam param;
param.Init(attrs->dict);
attrs->parsed = std::move(param);
}
NNVM_REGISTER_OP(tvm_op)
.set_attr_parser(TVMOpParamParser)
.set_num_inputs([](const NodeAttrs& attrs) {
const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
return param.num_inputs;
})
.set_num_outputs([](const NodeAttrs& attrs) {
const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
return param.num_outputs;
});
bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(&header, sizeof(header));
strm->Write(&reserved, sizeof(reserved));
strm->Write(&tensor->ctx, sizeof(tensor->ctx));
strm->Write(&tensor->ndim, sizeof(tensor->ndim));
strm->Write(&tensor->dtype, sizeof(tensor->dtype));
int ndim = tensor->ndim;
strm->Write(tensor->shape, sizeof(int64_t) * ndim);
int type_size = tensor->dtype.bits / 8;
int64_t size = 1;
for (int i = 0; i < ndim; ++i) {
size *= tensor->shape[i];
}
int64_t data_byte_size = type_size * size;
strm->Write(&data_byte_size, sizeof(data_byte_size));
strm->Write(tensor->data, data_byte_size);
return true;
}
DLTensor* LoadDLTensor(dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header, sizeof(header)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved, sizeof(reserved)))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLTensor tensor;
CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype)))
<< "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim);
CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
<< "Invalid DLTensor file format";
DLTensor* ret;
CHECK_EQ(TVMArrayAlloc(shape.data(),
tensor.ndim,
tensor.dtype.code,
tensor.dtype.bits,
tensor.dtype.lanes,
static_cast<int>(tensor.ctx.device_type),
tensor.ctx.device_id,
&ret), 0) << TVMGetLastError();
int64_t size = 1;
int type_size = ret->dtype.bits / 8;
for (int i = 0; i < ret->ndim; ++i) {
size *= ret->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size)))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == type_size * size)
<< "Invalid DLTensor file format";
CHECK(strm->Read(ret->data, type_size * size))
<< "Invalid DLTensor file format";
return ret;
}
TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_EQ(args.size() % 2, 0u);
size_t num_params = args.size() / 2;
std::vector<std::string> names;
names.reserve(num_params);
std::vector<DLTensor*> arrays;
arrays.reserve(num_params);
for (size_t i = 0; i < num_params * 2; i += 2) {
names.emplace_back(args[i].operator std::string());
arrays.emplace_back(args[i + 1].operator DLTensor*());
}
std::string bytes;
dmlc::MemoryStringStream strm(&bytes);
dmlc::Stream* fo = &strm;
uint64_t header = kTVMNDArrayListMagic, reserved = 0;
fo->Write(&header, sizeof(header));
fo->Write(&reserved, sizeof(reserved));
fo->Write(names);
{
uint64_t sz = static_cast<uint64_t>(arrays.size());
fo->Write(&sz, sizeof(sz));
for (size_t i = 0; i < sz; ++i) {
SaveDLTensor(fo, arrays[i]);
}
}
TVMByteArray arr;
arr.data = bytes.c_str();
arr.size = bytes.length();
*rv = arr;
});
TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string bytes = args[0];
std::vector<DLTensor*> data;
std::vector<std::string> names;
dmlc::MemoryStringStream memstrm(&bytes);
dmlc::Stream* strm = &memstrm;
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid parameters file format";
CHECK(header == kTVMNDArrayListMagic)
<< "Invalid parameters file format";
CHECK(strm->Read(&reserved))
<< "Invalid parameters file format";
CHECK(strm->Read(&names))
<< "Invalid parameters file format";
uint64_t sz;
strm->Read(&sz, sizeof(sz));
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size())
<< "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
data.push_back(LoadDLTensor(strm));
}
auto packed = [data, names](TVMArgs args, TVMRetValue* rv) {
int code = args[0];
if (code == 0) {
*rv = static_cast<int64_t>(data.size());
} else if (code == 1) {
int index = args[1];
*rv = names[index];
} else {
CHECK_EQ(code, 2);
int index = args[1];
*rv = static_cast<void*>(data[index]);
}
};
*rv = PackedFunc(packed);
});
} // namespace compiler
} // namespace nnvm
/*!
* Copyright (c) 2017 by Contributors
* \file graph_runtime.h
* \brief Interface code with TVM graph runtime.
*/
#ifndef NNVM_COMPILER_GRAPH_RUNTIME_H_
#define NNVM_COMPILER_GRAPH_RUNTIME_H_
#include <nnvm/graph.h>
#include <vector>
#include "../../tvm/src/runtime/graph/graph_runtime.h"
namespace nnvm {
namespace compiler {
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
std::string func_name;
uint32_t num_inputs;
uint32_t num_outputs;
uint32_t flatten_data;
DMLC_DECLARE_PARAMETER(TVMOpParam) {
DMLC_DECLARE_FIELD(func_name);
DMLC_DECLARE_FIELD(num_inputs).set_default(1);
DMLC_DECLARE_FIELD(num_outputs).set_default(1);
DMLC_DECLARE_FIELD(flatten_data).set_default(0);
}
};
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_GRAPH_RUNTIME_H_
@@ -70,7 +70,8 @@ def test_precompute_prune():
    m = graph_runtime.create(graph, lib, tvm.cpu(0))
    params["y"] = ny
    res = tvm.nd.empty(shape)
    m["load_params"](nnvm.compiler.save_param_dict(params))
    m.run()
    out = m.get_output(0, out=res)
    np.testing.assert_allclose(
        res.asnumpy(), nx.asnumpy() + 1 + ny.asnumpy() + na.asnumpy())
......
import numpy as np
import nnvm.compiler

def test_save_load():
    x = np.random.uniform(size=(10, 2)).astype("float32")
    y = np.random.uniform(size=(1, 2, 3)).astype("float32")
    x[:] = 1
    y[:] = 1
    params = {"x": x, "y": y}
    param_bytes = nnvm.compiler.save_param_dict(params)
    assert isinstance(param_bytes, bytearray)
    param2 = nnvm.compiler.load_param_dict(param_bytes)
    assert len(param2) == 2
    np.testing.assert_equal(param2["x"].asnumpy(), x)
    np.testing.assert_equal(param2["y"].asnumpy(), y)

if __name__ == "__main__":
    test_save_load()
@@ -10,6 +10,9 @@ def test_json_pass():
    ret._set_json_attr('json', ret.json_attr('json'))
    g2 = ret.apply('LoadJSON')
    assert g2.apply('SaveJSON').json_attr('json') == ret.json_attr('json')
    json = g.json()
    g2 = graph.load_json(json)
    assert json == g2.json()

def test_json_pass_with_attr():
......