Commit f21b5cab by Tianqi Chen

[DOCS] Add save_param_dict, readme (#42)

parent b7b00611
@@ -3,16 +3,54 @@
[![Build Status](https://travis-ci.org/dmlc/nnvm.svg?branch=master)](https://travis-ci.org/dmlc/nnvm)
[![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE)
-NNVM is a reusable computational graph optimization and compilation stack for deep learning systems.
-NNVM provides modules to:
+NNVM is a reusable computational graph optimization and compilation stack for deep learning systems. It provides modules to:
- Represent deep learning workloads from front-end frameworks via a graph IR.
- Optimize computation graphs to improve performance.
- Compile into executable modules and deploy to different hardware backends with minimum dependency.
-NNVM is designed to add new frontend, operators and graph optimizations in a decentralized fashion without changing the core interface. NNVM is part of [TVM stack](https://github.com/dmlc/tvm), which provides an end to end IR compilation stack for deploying deep learning workloads into different hardware backends
+NNVM is designed to add new frontends, operators, and graph optimizations in a decentralized fashion without changing the core interface. NNVM is part of the [TVM stack](https://github.com/dmlc/tvm), and the NNVM compiler toolchain can target the hardware backends supported by TVM.
The compiled module can be deployed to servers, mobile and embedded devices, and browsers with minimal dependencies, with bindings in languages including C++, Python, JavaScript, Java, and Objective-C.
The following code snippet demonstrates the general workflow of the nnvm compiler toolchain.
```python
import tvm
from tvm.contrib import graph_runtime, rpc
import nnvm.frontend
import nnvm.compiler
# get model from frameworks
# change xyz to supported framework name.
graph, params = nnvm.frontend.from_xyz(...)
# optimize and compile the graph to get a deployable module
# target can be "opencl", "llvm", "metal" or any target supported by tvm
target = "cuda"
graph, lib, params = nnvm.compiler.build(
    graph, target, shape={"data": data_shape}, params=params)
# deploy and run on gpu(0)
module = graph_runtime.create(graph, lib, tvm.gpu(0))
module.set_input(**params)
output = tvm.nd.empty(out_shape, ctx=tvm.gpu(0))
for data_array in dataset:
    module.set_input("data", data_array)
    module.run()
    module.get_output(0, output)
# deploy to remote mobile/rasp/browser with minimum tvm rpc runtime
# useful for quick experiments on mobile devices
remote = rpc.connect(remote_host, remote_port)
lib.export_library("mylib.so")
remote.upload("mylib.so")
rlib = remote.load_module("mylib.so")
# run on remote device
rmodule = graph_runtime.create(graph, rlib, remote.gpu(0))
rmodule.set_input(**params)
rmodule.run()
```
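The compiled artifacts can also be persisted for later deployment. Below is a minimal sketch (file names are illustrative) that continues the snippet above and uses the `save_param_dict` API added in this commit:

```python
# continue from the snippet above: persist the three build artifacts
lib.export_library("deploy.so")                     # compiled operator library
with open("deploy.json", "w") as f:
    f.write(graph.json())                           # serialized graph structure
with open("deploy.params", "wb") as f:
    f.write(nnvm.compiler.save_param_dict(params))  # serialized weights
```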
## Links
- [TinyFlow](https://github.com/tqchen/tinyflow) shows how you can use NNVM to build a TensorFlow-like API.
- [Apache MXNet](http://mxnet.io/) uses NNVM as a backend.
@@ -7,6 +7,10 @@ nnvm.compiler
.. autofunction:: nnvm.compiler.build_config
.. autofunction:: nnvm.compiler.save_param_dict
.. autofunction:: nnvm.compiler.load_param_dict
.. autofunction:: nnvm.compiler.optimize
.. automodule:: nnvm.compiler.graph_util
......
"""NNVM compiler toolchain.
-User only need to use :any:`build` and :any:`build_config` to do the compilation.
+Users only need to use :any:`build` and :any:`build_config` to do the compilation,
+and :any:`save_param_dict` to save the parameters into bytes.
The other APIs are for more advanced interaction with the compiler toolchain.
"""
from __future__ import absolute_import
@@ -10,6 +11,7 @@ import tvm
from . import build_module
from . build_module import build, optimize, build_config
from . compile_engine import engine, graph_key
from . param_dict import save_param_dict, load_param_dict
from .. import symbol as _symbol
from .. import graph as _graph
......
@@ -9,7 +9,7 @@ from . import graph_attr, graph_util
from .. import graph as _graph
OPT_PASS_LEVEL = {
"SimplifyInference": 2,
"SimplifyInference": 0,
"PrecomputePrune": 2,
"OpFusion": 1
}
@@ -26,6 +26,7 @@ class BuildConfig(object):
current = None
defaults = {
"opt_level": 2,
"add_pass": None,
}
def __init__(self, **kwargs):
self._old_scope = None
@@ -53,6 +54,23 @@ class BuildConfig(object):
assert self._old_scope
BuildConfig.current = self._old_scope
def pass_enabled(self, pass_name):
"""Get whether pass is enabled.
Parameters
----------
pass_name : str
The optimization pass name
Returns
-------
enabled : bool
Whether pass is enabled.
"""
if self.add_pass and pass_name in self.add_pass:
return True
return self.opt_level >= OPT_PASS_LEVEL[pass_name]
BuildConfig.current = BuildConfig()
@@ -64,6 +82,9 @@ def build_config(**kwargs):
opt_level: int, default=2
Optimization level. See OPT_PASS_LEVEL for level of each pass.
add_pass: set of str
Optimization passes to be added regardless of the optimization level.
Returns
-------
config: BuildConfig
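As a hedged usage sketch (`net` and `params` assumed as before), `add_pass` force-enables a pass that the current `opt_level` would otherwise skip:

```python
import nnvm.compiler

# opt_level=0 is below PrecomputePrune's level (2), so the pass would
# normally be skipped; add_pass enables it regardless.
with nnvm.compiler.build_config(opt_level=0, add_pass={"PrecomputePrune"}):
    graph, lib, params = nnvm.compiler.build(
        net, "llvm", shape={"data": (1, 3, 224, 224)}, params=params)
```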
@@ -120,7 +141,7 @@ def optimize(graph, shape, dtype="float32"):
"""
# pylint: disable=unused-argument
cfg = BuildConfig.current
if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyInference"]:
if cfg.pass_enabled("SimplifyInference"):
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph.apply(["InferShape", "SimplifyInference"])
return graph
@@ -182,14 +203,17 @@ def build(graph, target, shape, dtype="float32", params=None):
# Apply optimization
graph = optimize(graph, shape, dtype)
# Precompute prune
-if params and cfg.opt_level >= OPT_PASS_LEVEL["PrecomputePrune"]:
+if params and cfg.pass_enabled("PrecomputePrune"):
graph, params = precompute_prune(graph, params)
shape, dtype = _update_shape_dtype(shape, dtype, params)
# Operator Fusion and generation
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph_attr.set_dtype_inputs(graph, dtype)
graph._set_json_attr("target", target, "str")
graph._set_json_attr("opt_level", cfg.opt_level, "int")
if cfg.pass_enabled("OpFusion"):
graph._set_json_attr("opt_level", 1, "int")
else:
graph._set_json_attr("opt_level", 0, "int")
graph = graph.apply("InferShape").apply("InferType")
graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
libmod = graph_attr._move_out_module(graph, "module")
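With this change, fusion is gated by `pass_enabled("OpFusion")` rather than the raw `opt_level` attribute, so it composes with `add_pass`. A small sketch of the two resulting behaviors (same assumed `net`, `shape`, `params` as above):

```python
import nnvm.compiler

# opt_level=0 < OPT_PASS_LEVEL["OpFusion"] (1): fusion disabled
with nnvm.compiler.build_config(opt_level=0):
    g0, lib0, p0 = nnvm.compiler.build(net, "llvm", shape=shape, params=params)

# same opt_level, but fusion explicitly re-enabled via add_pass
with nnvm.compiler.build_config(opt_level=0, add_pass={"OpFusion"}):
    g1, lib1, p1 = nnvm.compiler.build(net, "llvm", shape=shape, params=params)
```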
......
"""Helper utility to save parameter dict"""
import tvm
_save_param_dict = tvm.get_global_func("nnvm.compiler._save_param_dict")
_load_param_dict = tvm.get_global_func("nnvm.compiler._load_param_dict")
def save_param_dict(params):
"""Save parameter dictionary to binary bytes.
The resulting binary bytes can be loaded by the
GraphModule through its "load_params" API.
Parameters
----------
params : dict of str to NDArray
The parameter dictionary.
Returns
-------
param_bytes: bytearray
Serialized parameters.
Examples
--------
.. code-block:: python
# compile and save the modules to file.
graph, lib, params = nnvm.compiler.build(
graph, target, shape={"data": data_shape}, params=params)
module = graph_runtime.create(graph, lib, tvm.gpu(0))
# save the parameters as byte array
param_bytes = nnvm.compiler.save_param_dict(params)
# We can serialize the param_bytes and load it back later.
# Pass in byte array to module to directly set parameters
module["load_params"](param_bytes)
"""
args = []
for k, v in params.items():
args.append(k)
args.append(tvm.nd.array(v))
return _save_param_dict(*args)
def load_param_dict(param_bytes):
"""Load parameter dictionary to binary bytes.
Parameters
----------
param_bytes: bytearray
Serialized parameters.
Returns
-------
params : dict of str to NDArray
The parameter dictionary.
"""
if isinstance(param_bytes, (bytes, str)):
param_bytes = bytearray(param_bytes)
load_mod = _load_param_dict(param_bytes)
size = load_mod(0)
param_dict = {}
for i in range(size):
key = load_mod(1, i)
dltensor_handle = load_mod(2, i)
param_dict[key] = tvm.nd.NDArray(dltensor_handle, False)
return param_dict
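A round-trip sketch through disk (file name illustrative):

```python
import numpy as np
import nnvm.compiler

params = {"w": np.zeros((2, 2), dtype="float32")}
param_bytes = nnvm.compiler.save_param_dict(params)

# the byte array survives ordinary file I/O
with open("net.params", "wb") as f:
    f.write(param_bytes)
with open("net.params", "rb") as f:
    loaded = nnvm.compiler.load_param_dict(bytearray(f.read()))
np.testing.assert_equal(loaded["w"].asnumpy(), params["w"])
```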
@@ -12,7 +12,7 @@ from ._base import _LIB
from ._base import c_array, c_str, nn_uint, py_str, string_types
from ._base import GraphHandle, SymbolHandle
from ._base import check_call
-from .symbol import Symbol, Group as _Group
+from .symbol import Variable, Symbol, Group as _Group
class GraphIndex(object):
"""Index for quickly accessing graph attributes.
@@ -174,9 +174,19 @@ class Graph(object):
check_call(_LIB.NNGraphGetSymbol(self.handle, ctypes.byref(shandle)))
return Symbol(shandle)
def json(self):
"""Get JSON representation of the graph
Returns
-------
json : str
JSON representation of the graph
"""
return self.apply("SaveJSON").json_attr("json")
def _tvm_graph_json(self):
"""Get TVM graph json"""
return self.apply("SaveJSON").json_attr("json")
return self.json()
@property
def index(self):
@@ -225,6 +235,24 @@ class Graph(object):
return Graph(ghandle)
def load_json(json_str):
"""Create a new graph by loading from json
Parameters
----------
json_str : str
The json string
Returns
-------
graph : Graph
The loaded graph
"""
ret = create(Variable("x"))
ret._set_json_attr("json", json_str)
return ret.apply("LoadJSON")
def create(symbol):
"""Create a new graph from symbol.
......
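The `json`/`load_json` pair added above gives a plain-text round trip for graphs; a small sketch (op choice is illustrative):

```python
import nnvm.symbol as sym
import nnvm.graph as graph

x = sym.Variable("x")
g = graph.create(sym.elemwise_add(x, x))

json_str = g.json()              # serialize to JSON text
g2 = graph.load_json(json_str)   # reconstruct an equivalent graph
assert g2.json() == json_str
```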
@@ -15,46 +15,10 @@
#include <tvm/lowered_func.h>
#include <dmlc/parameter.h>
#include "./compile_engine.h"
#include "../../tvm/src/runtime/graph/graph_runtime.h"
#include "./graph_runtime.h"
namespace nnvm {
namespace compiler {
-struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
-  std::string func_name;
-  uint32_t num_inputs;
-  uint32_t num_outputs;
-  uint32_t flatten_data;
-  DMLC_DECLARE_PARAMETER(TVMOpParam) {
-    DMLC_DECLARE_FIELD(func_name);
-    DMLC_DECLARE_FIELD(num_inputs).set_default(1);
-    DMLC_DECLARE_FIELD(num_outputs).set_default(1);
-    DMLC_DECLARE_FIELD(flatten_data).set_default(0);
-  }
-};
-DMLC_REGISTER_PARAMETER(TVMOpParam);
-// parser
-inline void TVMOpParamParser(nnvm::NodeAttrs* attrs) {
-  TVMOpParam param;
-  param.Init(attrs->dict);
-  attrs->parsed = std::move(param);
-}
-NNVM_REGISTER_OP(tvm_op)
-.set_attr_parser(TVMOpParamParser)
-.set_num_inputs([](const NodeAttrs& attrs) {
-  const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
-  return param.num_inputs;
-})
-.set_num_outputs([](const NodeAttrs& attrs) {
-  const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
-  return param.num_outputs;
-});
using namespace tvm;
// The single fuse rule.
......
/*!
* Copyright (c) 2017 by Contributors
* \file graph_runtime.cc
* \brief Interface code with TVM graph runtime.
*/
#include <dmlc/memory_io.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/registry.h>
#include "./graph_runtime.h"
namespace nnvm {
namespace compiler {
using tvm::runtime::TVMArgs;
using tvm::runtime::TVMRetValue;
using tvm::runtime::PackedFunc;
using tvm::runtime::kTVMNDArrayMagic;
using tvm::runtime::kTVMNDArrayListMagic;
DMLC_REGISTER_PARAMETER(TVMOpParam);
// parser
inline void TVMOpParamParser(nnvm::NodeAttrs* attrs) {
TVMOpParam param;
param.Init(attrs->dict);
attrs->parsed = std::move(param);
}
NNVM_REGISTER_OP(tvm_op)
.set_attr_parser(TVMOpParamParser)
.set_num_inputs([](const NodeAttrs& attrs) {
const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
return param.num_inputs;
})
.set_num_outputs([](const NodeAttrs& attrs) {
const TVMOpParam& param = nnvm::get<TVMOpParam>(attrs.parsed);
return param.num_outputs;
});
bool SaveDLTensor(dmlc::Stream* strm, DLTensor* tensor) {
uint64_t header = kTVMNDArrayMagic, reserved = 0;
strm->Write(&header, sizeof(header));
strm->Write(&reserved, sizeof(reserved));
strm->Write(&tensor->ctx, sizeof(tensor->ctx));
strm->Write(&tensor->ndim, sizeof(tensor->ndim));
strm->Write(&tensor->dtype, sizeof(tensor->dtype));
int ndim = tensor->ndim;
strm->Write(tensor->shape, sizeof(int64_t) * ndim);
int type_size = tensor->dtype.bits / 8;
int64_t size = 1;
for (int i = 0; i < ndim; ++i) {
size *= tensor->shape[i];
}
int64_t data_byte_size = type_size * size;
strm->Write(&data_byte_size, sizeof(data_byte_size));
strm->Write(tensor->data, data_byte_size);
return true;
}
DLTensor* LoadDLTensor(dmlc::Stream* strm) {
uint64_t header, reserved;
CHECK(strm->Read(&header, sizeof(header)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&reserved, sizeof(reserved)))
<< "Invalid DLTensor file format";
CHECK(header == kTVMNDArrayMagic)
<< "Invalid DLTensor file format";
DLTensor tensor;
CHECK(strm->Read(&tensor.ctx, sizeof(tensor.ctx)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.ndim, sizeof(tensor.ndim)))
<< "Invalid DLTensor file format";
CHECK(strm->Read(&tensor.dtype, sizeof(tensor.dtype)))
<< "Invalid DLTensor file format";
std::vector<int64_t> shape(tensor.ndim);
CHECK(strm->Read(&shape[0], sizeof(int64_t) * tensor.ndim))
<< "Invalid DLTensor file format";
DLTensor* ret;
CHECK_EQ(TVMArrayAlloc(shape.data(),
tensor.ndim,
tensor.dtype.code,
tensor.dtype.bits,
tensor.dtype.lanes,
static_cast<int>(tensor.ctx.device_type),
tensor.ctx.device_id,
&ret), 0) << TVMGetLastError();
int64_t size = 1;
int type_size = ret->dtype.bits / 8;
for (int i = 0; i < ret->ndim; ++i) {
size *= ret->shape[i];
}
int64_t data_byte_size;
CHECK(strm->Read(&data_byte_size, sizeof(data_byte_size)))
<< "Invalid DLTensor file format";
CHECK(data_byte_size == type_size * size)
<< "Invalid DLTensor file format";
CHECK(strm->Read(ret->data, type_size * size))
<< "Invalid DLTensor file format";
return ret;
}
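For reference, a hedged Python reader for one tensor record as written by `SaveDLTensor` above; it assumes little-endian byte order and the usual dlpack layouts (`DLContext` as two 32-bit ints, `DLDataType` as `uint8` code, `uint8` bits, `uint16` lanes):

```python
import struct

def read_dltensor(buf, offset=0):
    # 8-byte magic (kTVMNDArrayMagic) + 8 reserved bytes
    header, _reserved = struct.unpack_from("<QQ", buf, offset)
    offset += 16
    # DLContext (device_type, device_id), then ndim
    device_type, device_id, ndim = struct.unpack_from("<iii", buf, offset)
    offset += 12
    # DLDataType: code, bits, lanes
    code, bits, lanes = struct.unpack_from("<BBH", buf, offset)
    offset += 4
    shape = struct.unpack_from("<%dq" % ndim, buf, offset)
    offset += 8 * ndim
    (data_byte_size,) = struct.unpack_from("<q", buf, offset)
    offset += 8
    data = bytes(buf[offset:offset + data_byte_size])
    offset += data_byte_size
    return {"shape": shape, "dtype": (code, bits, lanes), "data": data}, offset
```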
TVM_REGISTER_GLOBAL("nnvm.compiler._save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_EQ(args.size() % 2, 0u);
size_t num_params = args.size() / 2;
std::vector<std::string> names;
names.reserve(num_params);
std::vector<DLTensor*> arrays;
arrays.reserve(num_params);
for (size_t i = 0; i < num_params * 2; i += 2) {
names.emplace_back(args[i].operator std::string());
arrays.emplace_back(args[i + 1].operator DLTensor*());
}
std::string bytes;
dmlc::MemoryStringStream strm(&bytes);
dmlc::Stream* fo = &strm;
uint64_t header = kTVMNDArrayListMagic, reserved = 0;
fo->Write(&header, sizeof(header));
fo->Write(&reserved, sizeof(reserved));
fo->Write(names);
{
uint64_t sz = static_cast<uint64_t>(arrays.size());
fo->Write(&sz, sizeof(sz));
for (size_t i = 0; i < sz; ++i) {
SaveDLTensor(fo, arrays[i]);
}
}
TVMByteArray arr;
arr.data = bytes.c_str();
arr.size = bytes.length();
*rv = arr;
});
TVM_REGISTER_GLOBAL("nnvm.compiler._load_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
std::string bytes = args[0];
std::vector<DLTensor*> data;
std::vector<std::string> names;
dmlc::MemoryStringStream memstrm(&bytes);
dmlc::Stream* strm = &memstrm;
uint64_t header, reserved;
CHECK(strm->Read(&header))
<< "Invalid parameters file format";
CHECK(header == kTVMNDArrayListMagic)
<< "Invalid parameters file format";
CHECK(strm->Read(&reserved))
<< "Invalid parameters file format";
CHECK(strm->Read(&names))
<< "Invalid parameters file format";
uint64_t sz;
strm->Read(&sz, sizeof(sz));
size_t size = static_cast<size_t>(sz);
CHECK(size == names.size())
<< "Invalid parameters file format";
for (size_t i = 0; i < size; ++i) {
data.push_back(LoadDLTensor(strm));
}
auto packed = [data, names](TVMArgs args, TVMRetValue* rv) {
int code = args[0];
if (code == 0) {
*rv = static_cast<int64_t>(data.size());
} else if (code == 1) {
int index = args[1];
*rv = names[index];
} else {
CHECK_EQ(code, 2);
int index = args[1];
*rv = static_cast<void*>(data[index]);
}
};
*rv = PackedFunc(packed);
});
} // namespace compiler
} // namespace nnvm
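Putting the two registered functions together, the container written by `_save_param_dict` is: the list magic and a reserved word, the name list in dmlc `Stream` format (a `uint64` count, then a `uint64` length plus raw bytes per string), a `uint64` tensor count, and the tensor records themselves. A hedged reader, reusing `read_dltensor` from the sketch above:

```python
import struct

def read_param_dict(buf):
    offset = 16  # skip kTVMNDArrayListMagic + reserved word
    (n,) = struct.unpack_from("<Q", buf, offset)
    offset += 8
    names = []
    for _ in range(n):
        (length,) = struct.unpack_from("<Q", buf, offset)
        offset += 8
        names.append(bytes(buf[offset:offset + length]).decode())
        offset += length
    (count,) = struct.unpack_from("<Q", buf, offset)
    offset += 8
    assert count == n, "name/tensor count mismatch"
    tensors = []
    for _ in range(count):
        record, offset = read_dltensor(buf, offset)
        tensors.append(record)
    return dict(zip(names, tensors))
```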
/*!
* Copyright (c) 2017 by Contributors
* \file graph_runtime.h
* \brief Interface code with TVM graph runtime.
*/
#ifndef NNVM_COMPILER_GRAPH_RUNTIME_H_
#define NNVM_COMPILER_GRAPH_RUNTIME_H_
#include <nnvm/graph.h>
#include <vector>
#include "../../tvm/src/runtime/graph/graph_runtime.h"
namespace nnvm {
namespace compiler {
struct TVMOpParam : public dmlc::Parameter<TVMOpParam> {
std::string func_name;
uint32_t num_inputs;
uint32_t num_outputs;
uint32_t flatten_data;
DMLC_DECLARE_PARAMETER(TVMOpParam) {
DMLC_DECLARE_FIELD(func_name);
DMLC_DECLARE_FIELD(num_inputs).set_default(1);
DMLC_DECLARE_FIELD(num_outputs).set_default(1);
DMLC_DECLARE_FIELD(flatten_data).set_default(0);
}
};
} // namespace compiler
} // namespace nnvm
#endif // NNVM_COMPILER_GRAPH_RUNTIME_H_
@@ -70,7 +70,8 @@ def test_precompute_prune():
m = graph_runtime.create(graph, lib, tvm.cpu(0))
params["y"] = ny
res = tvm.nd.empty(shape)
-m.run(**params)
+m["load_params"](nnvm.compiler.save_param_dict(params))
+m.run()
out = m.get_output(0, out=res)
np.testing.assert_allclose(
res.asnumpy(), nx.asnumpy() + 1 + ny.asnumpy() + na.asnumpy())
......
import numpy as np
import nnvm.compiler
def test_save_load():
x = np.random.uniform(size=(10, 2)).astype("float32")
y = np.random.uniform(size=(1, 2, 3)).astype("float32")
x[:] = 1
y[:] = 1
params = {"x": x, "y": y}
param_bytes = nnvm.compiler.save_param_dict(params)
assert isinstance(param_bytes, bytearray)
param2 = nnvm.compiler.load_param_dict(param_bytes)
assert len(param2) == 2
np.testing.assert_equal(param2["x"].asnumpy(), x)
np.testing.assert_equal(param2["y"].asnumpy(), y)
if __name__ == "__main__":
test_save_load()
@@ -10,6 +10,9 @@ def test_json_pass():
ret._set_json_attr('json', ret.json_attr('json'))
g2 = ret.apply('LoadJSON')
assert g2.apply('SaveJSON').json_attr('json') == ret.json_attr('json')
json = g.json()
g2 = graph.load_json(json)
assert json == g2.json()
def test_json_pass_with_attr():
......