Commit a2ab3d83 by Tianqi Chen

[RUNTIME][COMPILER] Formal compiler pipeline, runtime wrapper module (#21)

* [RUNTIME][COMPILER] Formal compiler pipeline, runtime wrapper module

* more detailed comments
parent 79ceb9f7
......@@ -4,7 +4,7 @@ from __future__ import absolute_import
import tvm
from . import build_module
from . build_module import build, precompute_prune, _run_graph
from . build_module import build, optimize, build_config
from .. import symbol as _symbol
from .. import graph as _graph
......
......@@ -3,10 +3,74 @@
from __future__ import absolute_import as _abs
import tvm
from . import graph_attr, graph_pass
from . import graph_attr, graph_util
from .. import graph as _graph
from .. import runtime
OPT_PASS_LEVEL = {
"SimplifyBatchNormInference": 2,
"PrecomputePrune": 2,
"OpFusion": 1
}
# Optimization passes and the minimum opt_level at which each one is switched on (see OPT_PASS_LEVEL)
class BuildConfig(object):
    """Configuration scope to set a build config option.

    Instances are context managers: entering one makes it
    ``BuildConfig.current`` and exiting restores the previous scope.
    While a scope is active, options it does not set itself are taken
    from the enclosing scope; otherwise they fall back to
    ``BuildConfig.defaults``.

    Parameters
    ----------
    kwargs
        Keyword arguments of configurations to set.

    Raises
    ------
    ValueError
        If a keyword is not a key of ``BuildConfig.defaults``.
    """
    # The innermost active configuration scope.
    current = None
    # Every recognized option together with its default value.
    defaults = {
        "opt_level": 2,
    }

    def __init__(self, **kwargs):
        self._old_scope = None
        for k in kwargs:
            if k not in BuildConfig.defaults:
                raise ValueError(
                    "invalid argument %s, candidates are %s" %
                    (k, sorted(BuildConfig.defaults)))
        # Options given explicitly by the caller; never modified afterwards.
        self._own_attr = dict(kwargs)
        # Effective options; extended with the enclosing scope's in __enter__.
        self._attr = dict(kwargs)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails. Read _attr through
        # __dict__ so that an instance without _attr (for example one being
        # rebuilt by copy or pickle) cannot re-enter this method forever.
        attr = self.__dict__.get("_attr", {})
        if name in attr:
            return attr[name]
        if name in BuildConfig.defaults:
            return BuildConfig.defaults[name]
        # AttributeError (not KeyError) keeps hasattr/getattr-with-default working.
        raise AttributeError(
            "%s has no attribute %s" % (type(self).__name__, name))

    def __enter__(self):
        # pylint: disable=protected-access
        self._old_scope = BuildConfig.current
        # Start from the enclosing scope's options, then apply our own on top.
        merged = BuildConfig.current._attr.copy()
        merged.update(self._own_attr)
        self._attr = merged
        BuildConfig.current = self
        return self

    def __exit__(self, ptype, value, trace):
        assert self._old_scope is not None
        BuildConfig.current = self._old_scope
        self._old_scope = None
        # Drop the inherited options so that reusing this object inside a
        # different scope does not carry stale values along.
        self._attr = dict(self._own_attr)
BuildConfig.current = BuildConfig()
def build_config(**kwargs):
    """Create a build configuration from the given config variables.

    Parameters
    ----------
    opt_level: int, default=2
        Optimization level. See OPT_PASS_LEVEL for level of each pass.

    Returns
    -------
    config: BuildConfig
        The build configuration
    """
    config = BuildConfig(**kwargs)
    return config
@tvm.register_func("nnvm.compiler.lower")
def _lower(sch, inputs, func_name):
f = tvm.lower(sch, inputs, name=func_name)
......@@ -19,23 +83,45 @@ def _build(funcs, target):
return tvm.build(funcs, target=target)
def optimize(graph):
"""Perform graph optimization
def _update_shape_dtype(shape, dtype, params):
"""Update shape dtype given params information"""
if not params:
return shape, dtype
shape = shape.copy()
shape.update({k : v.shape for k, v in params.items()})
if isinstance(dtype, str):
for k, v in params.items():
if v.dtype != dtype:
raise ValueError(
"%s: dtype not expected %s vs %s" % (k, dtype, v.dtype))
else:
dtype = dtype.copy()
dtype.update({k : str(v.dtype) for k, v in params.items()})
return shape, dtype
def optimize(graph, shape, dtype="float32"):
    """Perform target and parameter invariant graph optimization.

    The batch norm simplification is applied only when the current
    build config's opt_level reaches the level registered for
    "SimplifyBatchNormInference" in OPT_PASS_LEVEL; otherwise the
    graph is returned untouched.

    Parameters
    ----------
    graph : Graph
        The graph to be optimized.

    shape : dict
        The input shapes, forwarded to graph_attr.set_shape_inputs.

    dtype : str or dict of str to str
        The input types to the graph. Currently unused.

    Returns
    -------
    graph : Graph
        The optimized graph.
    """
    # pylint: disable=unused-argument
    config = BuildConfig.current
    bn_level = OPT_PASS_LEVEL["SimplifyBatchNormInference"]
    if config.opt_level < bn_level:
        return graph
    # Shape inference has to run before the batch norm simplification pass.
    graph = graph_attr.set_shape_inputs(graph, shape)
    return graph.apply(["InferShape", "SimplifyBatchNormInference"])
def build(graph, target, shape, dtype="float32"):
def build(graph, target, shape, dtype="float32", params=None):
"""Build graph into runtime library.
This is the final step of graph compilation.
......@@ -54,6 +140,11 @@ def build(graph, target, shape, dtype="float32"):
dtype : str or dict of str to str
The input types to the graph
params : dict of str to NDArray
Input parameters to the graph that do not change
during inference time. Used for pre-compute
folding optimization.
Returns
-------
graph : Graph
......@@ -61,20 +152,33 @@ def build(graph, target, shape, dtype="float32"):
libmod : tvm.Module
The module that comes with the execution graph
params : dict of str to NDArray
The updated parameters of graph if params is passed.
This can be different from the params passed in.
"""
if not isinstance(target, str):
raise TypeError("require target to be str")
if not isinstance(shape, dict):
raise TypeError("require shape to be dict")
cfg = BuildConfig.current
graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
shape, dtype = _update_shape_dtype(shape, dtype, params)
# Apply optimization
graph = optimize(graph, shape, dtype)
# Precompute prune
if params and cfg.opt_level >= OPT_PASS_LEVEL["PrecomputePrune"]:
graph, params = precompute_prune(graph, params)
shape, dtype = _update_shape_dtype(shape, dtype, params)
# Operator Fusion and generation
graph = graph_attr.set_shape_inputs(graph, shape)
graph = graph_attr.set_dtype_inputs(graph, dtype)
graph._set_json_attr("target", target, "str")
graph._set_json_attr("opt_level", cfg.opt_level, "int")
graph = graph.apply("InferShape").apply("InferType")
graph = graph.apply("GraphFusePartition").apply("GraphFuse")
libmod = graph_attr._move_out_module(graph, "module")
return graph, libmod
return graph, libmod, params
def _run_graph(graph, params):
......@@ -98,9 +202,9 @@ def _run_graph(graph, params):
dtype = {k : v.dtype for k, v in params.items()}
target = "llvm"
ctx = tvm.cpu(0)
_, oshape = graph_pass.infer_shape(graph, **shape)
_, odtype = graph_pass.infer_dtype(graph, **dtype)
graph, libmod = build(graph, target, shape, dtype)
_, oshape = graph_util.infer_shape(graph, **shape)
_, odtype = graph_util.infer_dtype(graph, **dtype)
graph, libmod, _ = build(graph, target, shape, dtype)
m = runtime.create(graph, libmod, ctx)
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
for k, v in params.items():
......
......@@ -6,81 +6,3 @@ Principle:
- Composable API: break graph transformation pass as segments of small transformations.
"""
from __future__ import absolute_import as _abs
import tvm
from . import graph_attr
def infer_shape(graph, **shape):
    """Run shape inference on the graph from the given input shapes.

    Parameters
    ----------
    graph : Graph
        The graph to perform shape inference from

    shape : keyword arguments
        Known input shapes, forwarded to graph_attr.set_shape_inputs.

    Returns
    -------
    in_shape : list
        Shape of each input, in the order of graph.index.input_names

    out_shape: list
        Shape of each output, in the order of graph.index.output_entries
    """
    inferred = graph_attr.set_shape_inputs(graph, shape).apply("InferShape")
    shape_vec = inferred.json_attr("shape")
    idx = inferred.index
    in_shape = []
    for name in idx.input_names:
        in_shape.append(shape_vec[idx.entry_id(name)])
    out_shape = []
    for entry in idx.output_entries:
        out_shape.append(shape_vec[idx.entry_id(entry)])
    return in_shape, out_shape
def infer_dtype(graph, **dtype):
    """Run type inference on the graph from the given input types.

    Parameters
    ----------
    graph : Graph
        The graph to perform type inference from

    dtype : keyword arguments
        Known input dtypes, forwarded to graph_attr.set_dtype_inputs.

    Returns
    -------
    in_dtype : list
        Dtype of each input, in the order of graph.index.input_names

    out_dtype: list
        Dtype of each output, in the order of graph.index.output_entries
    """
    inferred = graph_attr.set_dtype_inputs(graph, dtype).apply("InferType")
    tcode_vec = inferred.json_attr("dtype")
    idx = inferred.index
    # The graph stores type codes; translate them through TCODE_TO_DTYPE.
    in_dtype = []
    for name in idx.input_names:
        in_dtype.append(graph_attr.TCODE_TO_DTYPE[tcode_vec[idx.entry_id(name)]])
    out_dtype = []
    for entry in idx.output_entries:
        out_dtype.append(graph_attr.TCODE_TO_DTYPE[tcode_vec[idx.entry_id(entry)]])
    return in_dtype, out_dtype
_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
def check_graph_equal(grapha, graphb):
    """Verify that two graphs have equal structure.

    Parameters
    ----------
    grapha : Graph
        The first graph

    graphb : Graph
        The second graph

    Raises
    ------
    ValueError
        Raised with the message produced by the deep comparison
        when that message is non-empty.
    """
    mismatch = _deep_compare(grapha, graphb)
    if not mismatch:
        return
    raise ValueError("Graph compare error: " + mismatch)
# pylint: disable=invalid-name
"""Utility function to get information from graph."""
from __future__ import absolute_import as _abs
import tvm
from . import graph_attr
def infer_shape(graph, **shape):
    """Run shape inference on the graph from the given input shapes.

    Parameters
    ----------
    graph : Graph
        The graph to perform shape inference from

    shape : keyword arguments
        Known input shapes, forwarded to graph_attr.set_shape_inputs.

    Returns
    -------
    in_shape : list
        Shape of each input, in the order of graph.index.input_names

    out_shape: list
        Shape of each output, in the order of graph.index.output_entries
    """
    inferred = graph_attr.set_shape_inputs(graph, shape).apply("InferShape")
    shape_vec = inferred.json_attr("shape")
    idx = inferred.index
    in_shape = []
    for name in idx.input_names:
        in_shape.append(shape_vec[idx.entry_id(name)])
    out_shape = []
    for entry in idx.output_entries:
        out_shape.append(shape_vec[idx.entry_id(entry)])
    return in_shape, out_shape
def infer_dtype(graph, **dtype):
    """Run type inference on the graph from the given input types.

    Parameters
    ----------
    graph : Graph
        The graph to perform type inference from

    dtype : keyword arguments
        Known input dtypes, forwarded to graph_attr.set_dtype_inputs.

    Returns
    -------
    in_dtype : list
        Dtype of each input, in the order of graph.index.input_names

    out_dtype: list
        Dtype of each output, in the order of graph.index.output_entries
    """
    inferred = graph_attr.set_dtype_inputs(graph, dtype).apply("InferType")
    tcode_vec = inferred.json_attr("dtype")
    idx = inferred.index
    # The graph stores type codes; translate them through TCODE_TO_DTYPE.
    in_dtype = []
    for name in idx.input_names:
        in_dtype.append(graph_attr.TCODE_TO_DTYPE[tcode_vec[idx.entry_id(name)]])
    out_dtype = []
    for entry in idx.output_entries:
        out_dtype.append(graph_attr.TCODE_TO_DTYPE[tcode_vec[idx.entry_id(entry)]])
    return in_dtype, out_dtype
_deep_compare = tvm.get_global_func("nnvm.graph.DeepCompare")
def check_graph_equal(grapha, graphb):
    """Verify that two graphs have equal structure.

    Parameters
    ----------
    grapha : Graph
        The first graph

    graphb : Graph
        The second graph

    Raises
    ------
    ValueError
        Raised with the message produced by the deep comparison
        when that message is non-empty.
    """
    mismatch = _deep_compare(grapha, graphb)
    if not mismatch:
        return
    raise ValueError("Graph compare error: " + mismatch)
......@@ -2,6 +2,82 @@
import tvm
from tvm.contrib import rpc
class Module(object):
    """Wrapper runtime module.

    This is a thin wrapper of the underlying TVM module.
    You can also directly call the set_input, run, and get_output
    functions of the underlying module (see ``__getitem__``).

    Parameters
    ----------
    tvm_module : tvm.Module
        The internal tvm module
    """
    def __init__(self, tvm_module):
        self.tvm_module = tvm_module
        # Look the packed functions up once instead of on every call.
        self._set_input = tvm_module["set_input"]
        self._run = tvm_module["run"]
        self._get_output = tvm_module["get_output"]

    def set_input(self, key=None, value=None, **params):
        """Set inputs to the module, by key/value pair and/or via kwargs.

        Every value is converted with ``tvm.nd.array`` before it is
        handed to the underlying ``set_input`` function.

        Parameters
        ----------
        key : int or str
            The input key

        value : the input value.
            The value to set for ``key``

        params : dict of str to NDArray
            Additional arguments

        Returns
        -------
        self : Module
            This module, so that calls can be chained.
        """
        # Compare against None: an integer key of 0 is a legitimate input
        # key and must not be skipped by a truthiness test.
        if key is not None:
            self._set_input(key, tvm.nd.array(value))
        for k, v in params.items():
            self._set_input(k, tvm.nd.array(v))
        return self

    def run(self, **input_dict):
        """Run forward execution of the graph

        Parameters
        ----------
        input_dict: dict of str to NDArray
            Input values to set (through ``set_input``) before running.
        """
        if input_dict:
            self.set_input(**input_dict)
        self._run()

    def get_output(self, index, out):
        """Get index-th output to out

        Parameters
        ----------
        index : int
            The output index

        out : tvm.NDArray
            The output array container

        Returns
        -------
        out : tvm.NDArray
            The same container that was passed in.
        """
        self._get_output(index, out)
        return out

    def __getitem__(self, key):
        """Get internal module function

        Parameters
        ----------
        key : str
            The key to the module.
        """
        return self.tvm_module[key]
def create(graph, libmod, ctx):
"""Create a runtime executor module given the graph and module.
......@@ -30,7 +106,6 @@ def create(graph, libmod, ctx):
hmod = rpc._ModuleHandle(libmod)
fcreate = ctx._rpc_sess.get_function("nnvm.runtime.remote_create")
device_type = device_type % rpc.RPC_SESS_MASK
return fcreate(json_str, hmod, device_type, device_id)
return Module(fcreate(json_str, hmod, device_type, device_id))
fcreate = tvm.get_global_func("nnvm.runtime.create")
return fcreate(json_str, libmod, device_type, device_id)
return Module(fcreate(json_str, libmod, device_type, device_id))
......@@ -41,6 +41,11 @@ DLDataType GetDLType(int type_flag) {
nnvm::Graph GraphFusePartition(nnvm::Graph g) {
// setup ref counter
const IndexedGraph& idx = g.indexed_graph();
int opt_level = 2;
if (g.attrs.count("opt_level") != 0) {
opt_level = g.MoveCopyAttr<int>("opt_level");
}
// Get attributes from the graph
const ShapeVector& shape_vec = g.GetAttr<ShapeVector>("shape");
const DTypeVector& dtype_vec = g.GetAttr<DTypeVector>("dtype");
......@@ -65,7 +70,7 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
// this line will realize all the outputs
ref_count[e.node_id] += 2;
}
// Pattern fo the subgraph
// Pattern for the subgraph
std::vector<TOpPattern> pattern_vec(idx.num_nodes(), kExtern);
// Whether node can be fused to parent.
std::vector<FuseRule> fuse_vec(idx.num_nodes(), FuseRule::kUknown);
......@@ -123,7 +128,7 @@ nnvm::Graph GraphFusePartition(nnvm::Graph g) {
}
pattern_vec[nid] = pt;
if (ref_count[nid] > 1) {
if (ref_count[nid] > 1 || opt_level < 1) {
fuse_vec[nid] = FuseRule::kRealize;
if (master_vec[nid] == -1) {
master_vec[nid] = nid;
......
......@@ -21,6 +21,10 @@ TVM_REGISTER_EXT_TYPE(nnvm::compiler::AttrDict);
} // namespace tvm
namespace nnvm {
DMLC_JSON_ENABLE_ANY(int, int);
} // namespace nnvm
namespace nnvm {
namespace compiler {
using tvm::Tensor;
......
......@@ -4,6 +4,7 @@ import tvm
import nnvm.symbol as sym
import nnvm.compiler
import nnvm.runtime
from nnvm.compiler.build_module import _run_graph, precompute_prune
def test_compile():
x = sym.Variable("x")
......@@ -12,8 +13,7 @@ def test_compile():
shape = (10, 128)
dtype = tvm.float32
shape_dict = {"x": shape, "y": shape}
graph, lib = nnvm.compiler.build(z, "llvm", shape_dict)
def verify(graph, lib):
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
......@@ -30,6 +30,18 @@ def test_compile():
np.testing.assert_allclose(
out.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))
graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict)
assert graph.index.num_nodes == 3
verify(graph, lib)
with nnvm.compiler.build_config(opt_level=0):
graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict)
# print(graph.ir())
assert graph.index.num_nodes == 4
verify(graph, lib)
def test_run():
x = sym.Variable("x")
......@@ -39,7 +51,7 @@ def test_run():
dtype = tvm.float32
nx = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
ny = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
res = nnvm.compiler._run_graph(z, {"x": nx, "y": ny})
res = _run_graph(z, {"x": nx, "y": ny})
np.testing.assert_allclose(
res[0].asnumpy(), np.exp(nx.asnumpy() + ny.asnumpy()))
......@@ -53,11 +65,16 @@ def test_precompute_prune():
nx = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
ny = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
params = {"x": nx}
graph, pdict = nnvm.compiler.precompute_prune(z, params)
pdict["y"] = ny
res = nnvm.compiler._run_graph(z, pdict)
graph, lib, params = nnvm.compiler.build(
z, "llvm", shape={"y": ny.shape}, params=params)
assert graph.index.num_nodes == 3
m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
params["y"] = ny
res = tvm.nd.empty(shape)
m.run(**params)
out = m.get_output(0, out=res)
np.testing.assert_allclose(
res[0].asnumpy(), nx.asnumpy() + 1 + ny.asnumpy())
res.asnumpy(), nx.asnumpy() + 1 + ny.asnumpy())
if __name__ == "__main__":
......
......@@ -2,16 +2,16 @@
import nnvm
import nnvm.compiler
from nnvm import symbol as sym
from nnvm.compiler import graph_pass, graph_attr
from nnvm.compiler import graph_util, graph_attr
def test_infer_attr():
x = sym.Variable("x")
y = x * 2
g = nnvm.graph.create(y)
ishape, oshape = graph_pass.infer_shape(g, x=(10,20))
ishape, oshape = graph_util.infer_shape(g, x=(10,20))
assert tuple(oshape[0]) == (10, 20)
itype, otype = graph_pass.infer_dtype(g, x="float32")
itype, otype = graph_util.infer_dtype(g, x="float32")
assert otype[0] == "float32"
if __name__ == "__main__":
......
......@@ -19,7 +19,7 @@ def test_rpc_executor():
tmp = util.tempdir()
lib_name = tmp.relpath("net.o")
graph, lib = nnvm.compiler.build(z, "llvm", shape_dict)
graph, lib, _ = nnvm.compiler.build(z, "llvm", shape_dict)
# save module
lib.save(lib_name)
remote = rpc.connect(host, port)
......
"""Unittest cases for simplify batch_norm"""
import nnvm
from nnvm import symbol as sym
from nnvm.compiler import graph_pass, graph_attr
from nnvm.compiler import graph_util, graph_attr
def test_simplify_batchnorm():
def simple_bn(x, gamma, beta, moving_mean, moving_var,
......@@ -40,7 +40,7 @@ def test_simplify_batchnorm():
# Some prints for debug
# print(g1.ir())
# assert graph equals as expected
graph_pass.check_graph_equal(g1, g2)
graph_util.check_graph_equal(g1, g2)
check(2, 1, 1)
check(4, 0, 3)
......
......@@ -26,7 +26,7 @@ def test_relu():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
......@@ -48,7 +48,7 @@ def test_exp():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
......@@ -70,7 +70,8 @@ def test_log():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
with nnvm.compiler.build_config(opt_level=1):
graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
......@@ -92,7 +93,8 @@ def test_tanh():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
with nnvm.compiler.build_config(opt_level=1):
graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
......@@ -114,7 +116,7 @@ def test_sigmoid():
dtype = "float32"
dshape = (1, 3, 32, 32)
oshape = dshape
graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
......@@ -136,7 +138,8 @@ def test_softmax():
dtype = "float32"
dshape = (10, 1000)
oshape = dshape
graph, lib = nnvm.compiler.build(y, default_target(), {"x": dshape})
with nnvm.compiler.build_config(opt_level=1):
graph, lib, _ = nnvm.compiler.build(y, default_target(), {"x": dshape})
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
......
......@@ -29,7 +29,7 @@ def test_conv2d():
kshape = (10, 3, 3, 3)
oshape = (1, 10, 18, 18)
shape_dict = {"x": dshape}
graph, lib = nnvm.compiler.build(y, default_target(), shape_dict)
graph, lib, _ = nnvm.compiler.build(y, default_target(), shape_dict)
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
......@@ -57,7 +57,7 @@ def test_grouped_conv2d():
kshape = (32, 1, 3, 3)
oshape = (1, 32, 18, 18)
shape_dict = {"x": dshape}
graph, lib = nnvm.compiler.build(y, default_target(), shape_dict)
graph, lib, _ = nnvm.compiler.build(y, default_target(), shape_dict)
m = nnvm.runtime.create(graph, lib, default_ctx())
# get member functions
set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment