Commit 7c95535c by Tianqi Chen

[PASS] PrecomputePrune, add testcase (#14)

* [PASS] PrecomputePrune, add testcase

* update comment
parent d27c11e0
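In short, this change renames the old PruneGraph pass to PrecomputePrune, exposes it from Python as nnvm.compiler.precompute_prune together with a CPU-only _run_graph helper, and switches the shape/dtype/layout graph attributes to be keyed by input name, with new infer_shape/infer_dtype passes on top. A minimal usage sketch, mirroring the new test case at the end of this diff (the exact API is defined in the changes below):

    import numpy as np
    import tvm
    import nnvm.symbol as sym
    import nnvm.compiler

    # "x" is a parameter, so the subexpression x + 1 can be folded ahead of time
    x = sym.Variable("x") + 1
    y = sym.Variable("y")
    z = y + x

    nx = tvm.nd.array(np.random.uniform(size=(10, 10)).astype("float32"))

    # returns the pruned graph and an updated param dict that holds the
    # pre-computed intermediate results
    graph, params = nnvm.compiler.precompute_prune(z, {"x": nx})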
@@ -30,7 +30,7 @@ ifneq ($(ADD_CFLAGS), NONE)
 endif
 ifneq ($(ADD_LDFLAGS), NONE)
-	LFFLAGS += $(ADD_LDFLAGS)
+	LDFLAGS += $(ADD_LDFLAGS)
 endif
 # plugin
@@ -46,6 +46,7 @@ ifeq ($(UNAME_S), Darwin)
 	SHARED_LIBRARY_SUFFIX := dylib
 	WHOLE_ARCH= -all_load
 	NO_WHOLE_ARCH= -noall_load
+	LDFLAGS += -undefined dynamic_lookup
 else
 	SHARED_LIBRARY_SUFFIX := so
 	WHOLE_ARCH= --whole-archive
......
@@ -4,7 +4,7 @@ from __future__ import absolute_import
 import tvm
 from . import build_module
-from . build_module import build
+from . build_module import build, precompute_prune, _run_graph
 from .. import symbol as _symbol
 from .. import graph as _graph
......
@@ -3,8 +3,9 @@
 from __future__ import absolute_import as _abs
 import tvm
-from . import graph_attr
+from . import graph_attr, graph_pass
 from .. import graph as _graph
+from .. import runtime

 @tvm.register_func("nnvm.compiler.lower")
 def _lower(sch, inputs, func_name):
@@ -18,9 +19,6 @@ def _build(funcs, target):
     return tvm.build(funcs, target=target)

-_move_module = tvm.get_global_func("nnvm.compiler._move_module")
-
 def optimize(graph):
     """Perform graph optimization
@@ -70,10 +68,83 @@ def build(graph, target, shape, dtype="float32"):
         raise TypeError("require shape to be dict")
     graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
-    graph = graph_attr.set_shape(graph, shape)
-    graph = graph_attr.set_dtype(graph, dtype)
+    graph = graph_attr.set_shape_inputs(graph, shape)
+    graph = graph_attr.set_dtype_inputs(graph, dtype)
     graph._set_json_attr("target", target, "str")
     graph = graph.apply("InferShape").apply("InferType")
     graph = graph.apply("GraphFusePartition").apply("GraphFuse")
-    libmod = _move_module(graph)
+    libmod = graph_attr._move_out_module(graph, "module")
     return graph, libmod
+
+def _run_graph(graph, params):
+    """Helper utility to build and run a graph and get its outputs, CPU only.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to be executed.
+
+    params : dict of str to ndarray
+        The parameter dictionary.
+
+    Returns
+    -------
+    out_list : list of tvm.NDArray
+        The outputs of the graph, in output order.
+    """
+    graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
+    shape = {k : v.shape for k, v in params.items()}
+    dtype = {k : v.dtype for k, v in params.items()}
+    target = "llvm"
+    ctx = tvm.cpu(0)
+    _, oshape = graph_pass.infer_shape(graph, **shape)
+    _, odtype = graph_pass.infer_dtype(graph, **dtype)
+    graph, libmod = build(graph, target, shape, dtype)
+    m = runtime.create(graph, libmod, ctx)
+    set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
+    for k, v in params.items():
+        set_input(k, tvm.nd.array(v))
+    run()
+    out_data = []
+    for i, kv in enumerate(zip(oshape, odtype)):
+        shape, dtype = kv
+        arr = tvm.nd.empty(shape, dtype, ctx)
+        get_output(i, arr)
+        out_data.append(arr)
+    return out_data
+
+
+def precompute_prune(graph, params):
+    """Precompute the part of the graph that can be pre-computed.
+
+    This creates a new graph that contains only the ops whose results
+    depend on the runtime inputs, together with an updated param dict
+    that holds the pre-computed intermediate results.
+
+    Parameters
+    ----------
+    graph : Graph
+        The input graph
+
+    params : dict of str -> tvm.NDArray
+        The parameter dictionary of the graph
+
+    Returns
+    -------
+    pruned_graph : Graph
+        The pruned graph
+
+    new_params : dict of str -> tvm.NDArray
+        The updated dictionary of parameters.
+    """
+    graph = graph if isinstance(graph, _graph.Graph) else _graph.create(graph)
+    graph._set_json_attr("param_name_list", list(params.keys()), "list_str")
+    graph = graph.apply("PrecomputePrune")
+    pre_graph = graph_attr._move_out_graph(graph, "precompute_graph")
+    if not pre_graph.symbol.list_output_names():
+        return graph, params
+    out_names = pre_graph.json_attr("output_names")
+    out_arrs = _run_graph(pre_graph, params)
+    return graph, dict(zip(out_names, out_arrs))

+# pylint: disable=invalid-name
 """Utilities to access graph attributes"""
 from __future__ import absolute_import as _abs
+import tvm

-def set_shape(g, shape):
-    """Set the shape of graph nodes in the graph attribute.
+def set_shape_inputs(g, shape):
+    """Set the shape of input graph nodes in the graph attribute.

     Parameters
     ----------
@@ -17,20 +20,24 @@ def set_shape(g, shape):
     g : Graph
         The updated graph with updated shape.
     """
-    index = g.index
-    list_shape = [[]] * index.num_node_entries
-    for k, v in shape.items():
-        list_shape[index.entry_id(k)] = v
-    g._set_json_attr("shape", list_shape, 'list_shape')
+    list_shape = [
+        shape.get(name, ()) for name in g.index.input_names]
+    g._set_json_attr("shape_inputs", list_shape, 'list_shape')
     return g

-DTYPE_DICT = {
+DTYPE_TO_TCODE = {
+    "default": -1,
     "float32": 0
 }

-def set_dtype(g, dtype):
-    """Set the dtype of graph nodes
+TCODE_TO_DTYPE = {
+    -1: None,
+    0: "float32"
+}
+
+def set_dtype_inputs(g, dtype):
+    """Set the dtype of input graph nodes

     Parameters
     ----------
@@ -45,12 +52,37 @@ def set_dtype(g, dtype):
     g : Graph
         The updated graph with updated dtype.
     """
-    index = g.index
     if isinstance(dtype, dict):
-        list_dtype = [-1] * index.num_node_entries
-        for k, v in dtype.items():
-            list_dtype[index.entry_id(k)] = DTYPE_DICT[v]
+        list_dtype = [
+            DTYPE_TO_TCODE[dtype.get(name, "default")]
+            for name in g.index.input_names]
     else:
-        list_dtype = [DTYPE_DICT[dtype]] * index.num_node_entries
-    g._set_json_attr("dtype", list_dtype, "list_int")
+        list_dtype = [DTYPE_TO_TCODE[dtype]] * len(g.index.input_names)
+    g._set_json_attr("dtype_inputs", list_dtype, "list_int")
     return g
+
+
+def set_layout_inputs(g, layout):
+    """Set the layout of input graph nodes
+
+    Parameters
+    ----------
+    g : Graph
+        The input graph
+
+    layout : dict of str to str or str
+        The input layout
+
+    Returns
+    -------
+    g : Graph
+        The updated graph with updated layout.
+    """
+    list_layout = [
+        layout.get(name, "default") for name in g.index.input_names]
+    g._set_json_attr("layout_inputs", list_layout, 'list_str')
+    return g
+
+_move_out_module = tvm.get_global_func("nnvm.graph_attr._move_module")
+_move_out_graph = tvm.get_global_func("nnvm.graph_attr._move_graph")
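Note the indexing change above: the old set_shape/set_dtype attributes were lists over all node entries, while the new *_inputs variants are lists ordered by the graph's input names, with missing entries falling back to an empty shape or the "default" tcode -1. A rough standalone illustration of the list that set_shape_inputs builds, using hypothetical input names (the real names come from g.index.input_names):

    # hypothetical two-input graph
    input_names = ["data", "weight"]
    shape = {"data": (10, 10)}
    list_shape = [shape.get(name, ()) for name in input_names]
    # list_shape == [(10, 10), ()]  -- unspecified inputs stay unknown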
+# pylint: disable=invalid-name
 """Namespace of graph pass.

 Principle:
@@ -5,3 +6,57 @@ Principle:
 - Composable API: break graph transformation pass as segments of small transformations.
 """
 from __future__ import absolute_import as _abs
+
+from . import graph_attr
+
+
+def infer_shape(graph, **shape):
+    """Infer the shapes of outputs given the shapes of inputs.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to perform shape inference on
+
+    Returns
+    -------
+    in_shape : list of tuple
+        Shape of inputs
+
+    out_shape : list of tuple
+        Shape of outputs
+    """
+    graph = graph_attr.set_shape_inputs(graph, shape)
+    graph = graph.apply("InferShape")
+    shape = graph.json_attr("shape")
+    index = graph.index
+    input_shape = [shape[index.entry_id(x)] for x in index.input_names]
+    output_shape = [shape[index.entry_id(x)] for x in index.output_entries]
+    return input_shape, output_shape
+
+
+def infer_dtype(graph, **dtype):
+    """Infer the dtypes of outputs given the dtypes of inputs.
+
+    Parameters
+    ----------
+    graph : Graph
+        The graph to perform type inference on
+
+    Returns
+    -------
+    in_dtype : list of str
+        Dtype of inputs
+
+    out_dtype : list of str
+        Dtype of outputs
+    """
+    graph = graph_attr.set_dtype_inputs(graph, dtype)
+    graph = graph.apply("InferType")
+    dtype = graph.json_attr("dtype")
+    index = graph.index
+    input_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
+                   for x in index.input_names]
+    output_dtype = [graph_attr.TCODE_TO_DTYPE[dtype[index.entry_id(x)]]
+                    for x in index.output_entries]
+    return input_dtype, output_dtype
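These two helpers are exercised by the new test_infer_attr unit test at the end of this diff; for quick reference, a usage sketch along the same lines:

    import nnvm
    from nnvm.compiler import graph_pass

    x = nnvm.symbol.Variable("x")
    g = nnvm.graph.create(x * 2)
    # keyword arguments name the graph inputs
    in_shape, out_shape = graph_pass.infer_shape(g, x=(10, 20))
    in_dtype, out_dtype = graph_pass.infer_dtype(g, x="float32")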
@@ -24,6 +24,8 @@ class GraphIndex(object):
         self.nodes = jgraph["nodes"]
         self.entry_ptr = jgraph["node_row_ptr"]
         self._name2nodeid = {n["name"]: i for i, n in enumerate(self.nodes)}
+        self.input_names = graph.symbol.list_input_names()
+        self.output_entries = jgraph["heads"]

     @property
     def num_nodes(self):
@@ -66,6 +68,10 @@ class GraphIndex(object):
         index : int
             The entry index
         """
+        if isinstance(key, (list, tuple)):
+            if len(key) != 3:
+                raise ValueError("Expect entry index to be tuple of 3 elems")
+            key, value_index, _ = key
         idx = self.node_id(key) if isinstance(key, str) else key
         assert value_index < self.entry_ptr[idx + 1]
         return self.entry_ptr[idx] + value_index
......
@@ -68,6 +68,21 @@ class AttrDict(object):
         """
         return int(self[key])

+    def get_float(self, key):
+        """Get float from attr dict
+
+        Parameters
+        ----------
+        key : str
+            The attr key
+
+        Returns
+        -------
+        value : float
+            The result value
+        """
+        return float(self[key])
+
     def get_bool(self, key):
         """Get bool from attr dict
......
@@ -17,6 +17,17 @@ def _schedule_broadcast(_, outs, target):
     tvm.schedule.AutoInlineInjective(s)
     return s

+def _compute_binary_scalar(f):
+    """auxiliary function"""
+    @tvm.tag_scope("ewise")
+    def _compute(attrs, x):
+        x = x[0]
+        scalar = attrs.get_float("scalar")
+        scalar = tvm.const(scalar, x.dtype)
+        return tvm.compute(x.shape, lambda *i: f(x(*i), scalar))
+    return _compute
+
 _fschedule_broadcast = tvm.convert(_schedule_broadcast)

 # exp
@@ -25,6 +36,12 @@ reg.register_compute("exp",
 reg.register_pattern("exp", OpPattern.ELEM_WISE)
 reg.register_schedule("exp", _fschedule_broadcast)

+# add scalar
+reg.register_compute("__add_scalar__",
+                     _compute_binary_scalar(lambda x, y: x + y))
+reg.register_pattern("__add_scalar__", OpPattern.ELEM_WISE)
+reg.register_schedule("__add_scalar__", _fschedule_broadcast)
+
 # broadcast_add
 reg.register_compute("broadcast_add",
                      lambda _, x: topi.broadcast_add(x[0], x[1]))
......
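The _compute_binary_scalar helper above turns a Python lambda into a compute function that reads its constant operand from the node's "scalar" attribute (via the new AttrDict.get_float). A rough standalone sketch of the tensor expression it produces for __add_scalar__, assuming the tvm.placeholder/tvm.compute API of this era, a (10, 10) float32 input, and scalar=1.0:

    import tvm

    x = tvm.placeholder((10, 10), name="x", dtype="float32")
    scalar = tvm.const(1.0, x.dtype)
    # elementwise add of the scalar constant, as in _compute_binary_scalar
    y = tvm.compute(x.shape, lambda *i: x(*i) + scalar)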
@@ -104,5 +104,19 @@ TVM_REGISTER_GLOBAL("nnvm._register_pattern")
     Op& op = ::dmlc::Registry<nnvm::Op>::Get()->__REGISTER_OR_GET__(args[0]);
     op.set_attr<TOpPattern>("TOpPattern", args[1].operator int(), args[2]);
   });

+TVM_REGISTER_GLOBAL("nnvm.graph_attr._move_module")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+    const nnvm::Graph& g = args[0].AsExtension<Graph>();
+    *rv = const_cast<nnvm::Graph*>(&g)->
+        MoveCopyAttr<tvm::runtime::Module>(args[1]);
+  });
+
+TVM_REGISTER_GLOBAL("nnvm.graph_attr._move_graph")
+.set_body([](TVMArgs args, TVMRetValue *rv) {
+    const nnvm::Graph& g = args[0].AsExtension<Graph>();
+    *rv = const_cast<nnvm::Graph*>(&g)->
+        MoveCopyAttr<nnvm::Graph>(args[1]);
+  });
+
 }  // namespace compiler
 }  // namespace nnvm
@@ -381,13 +381,5 @@ nnvm::Graph GraphFuse(nnvm::Graph g) {
 NNVM_REGISTER_PASS(GraphFuse)
 .set_body(GraphFuse);

-TVM_REGISTER_GLOBAL("nnvm.compiler._move_module")
-.set_body([](TVMArgs args, TVMRetValue *rv) {
-    const nnvm::Graph& g = args[0].AsExtension<Graph>();
-    *rv = const_cast<nnvm::Graph*>(&g)->
-        MoveCopyAttr<tvm::runtime::Module>("module");
-  });
-
 }  // namespace compiler
 }  // namespace nnvm
@@ -44,7 +44,7 @@ nnvm::Graph LayoutTransform(nnvm::Graph src) {
   const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
   const std::vector<TLayoutInfo>& input_layouts =
-      src.GetAttr<std::vector<TLayoutInfo> >("layout");
+      src.GetAttr<std::vector<TLayoutInfo> >("layout_inputs");
   const IndexedGraph& idx = src.indexed_graph();
   std::vector<TLayoutInfo> produce_vec(idx.num_node_entries(), GetDefaultLayout());
......
 /*!
  * Copyright (c) 2017 by Contributors
- * \file prune_graph.cc
- * \brief Prune the graph to do constant folding.
+ * \file precompute_prune.cc
+ * \brief Split the graph into a pre-compute graph and an execution graph.
  *
- * This pass breaks the graph into pre-compute graph
- * and the execution graph.
+ * The pre-compute graph outputs parameters that can be taken
+ * by the execution graph during the execution phase.
  */
 #include <nnvm/graph.h>
 #include <nnvm/op_attr_types.h>
@@ -16,11 +16,15 @@
 namespace nnvm {
 namespace compiler {

-nnvm::Graph PruneGraph(nnvm::Graph src) {
-  const auto& params = src.GetAttr<std::unordered_set<std::string> >("params");
+nnvm::Graph PrecomputePrune(nnvm::Graph src) {
+  const auto& plist
+      = src.GetAttr<std::vector<std::string> >("param_name_list");
+  std::unordered_set<std::string> params(plist.begin(), plist.end());

   std::unordered_set<nnvm::Node*> pruned;
   nnvm::NodeEntryMap<nnvm::NodePtr> entry_var;
+  std::unordered_set<std::string> unique_name;

   DFSVisit(src.outputs, [&](const nnvm::NodePtr& n) {
     bool can_be_pruned = true;
     if (n->is_variable()) {
@@ -45,7 +49,12 @@ nnvm::Graph PruneGraph(nnvm::Graph src) {
         nnvm::NodePtr var = nnvm::Node::Create();
         var->attrs.name = e.node->attrs.name + "_output" + std::to_string(e.index);
         entry_var.emplace(e, var);
+        CHECK(!unique_name.count(var->attrs.name));
+        unique_name.insert(var->attrs.name);
       }
+      // TODO(ziheng): this pass now mutates the original graph structure.
+      // This might not be a good thing, change to copy the structure instead.
       e = nnvm::NodeEntry{entry_var.at(e), 0, 0};
     }
   }
@@ -56,21 +65,21 @@ nnvm::Graph PruneGraph(nnvm::Graph src) {
   pre_graph.outputs.reserve(entry_var.size());
   std::vector<std::string> output_names;
   output_names.reserve(entry_var.size());

   for (auto kv : entry_var) {
     if (kv.first.node->is_variable()) continue;
     pre_graph.outputs.emplace_back(kv.first);
     output_names.emplace_back(kv.second->attrs.name);
   }

-  // new parameter list
-  pre_graph.attrs["pruned_params"] =
+  pre_graph.attrs["output_names"] =
       std::make_shared<dmlc::any>(std::move(output_names));
-  src.attrs["pre_graph"] =
+  src.attrs["precompute_graph"] =
       std::make_shared<dmlc::any>(std::move(pre_graph));
   return src;
 }

-NNVM_REGISTER_PASS(PruneGraph)
-.set_body(PruneGraph);
+NNVM_REGISTER_PASS(PrecomputePrune)
+.set_body(PrecomputePrune);

 }  // namespace compiler
 }  // namespace nnvm

@@ -17,8 +17,8 @@ def test_compile():
     m = nnvm.runtime.create(graph, lib, tvm.cpu(0))
     # get member functions
     set_input, run, get_output = m["set_input"], m["run"], m["get_output"]
-    na = tvm.nd.array(np.ones(shape).astype(dtype))
-    nb = tvm.nd.array(np.ones(shape).astype(dtype))
+    na = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
+    nb = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
     # set inputs
     set_input("x", na)
     set_input("y", nb)
@@ -30,5 +30,37 @@ def test_compile():
     np.testing.assert_allclose(
         out.asnumpy(), np.exp(na.asnumpy() + nb.asnumpy()))

+
+def test_run():
+    x = sym.Variable("x")
+    y = sym.Variable("y")
+    z = sym.exp(y + x)
+    shape = (10, 10)
+    dtype = tvm.float32
+    nx = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
+    ny = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
+    res = nnvm.compiler._run_graph(z, {"x": nx, "y": ny})
+    np.testing.assert_allclose(
+        res[0].asnumpy(), np.exp(nx.asnumpy() + ny.asnumpy()))
+
+
+def test_precompute_prune():
+    x = sym.Variable("x") + 1
+    y = sym.Variable("y")
+    z = y + x
+    shape = (10, 10)
+    dtype = tvm.float32
+    nx = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
+    ny = tvm.nd.array(np.random.uniform(size=shape).astype(dtype))
+    params = {"x": nx}
+    graph, pdict = nnvm.compiler.precompute_prune(z, params)
+    pdict["y"] = ny
+    res = nnvm.compiler._run_graph(z, pdict)
+    np.testing.assert_allclose(
+        res[0].asnumpy(), nx.asnumpy() + 1 + ny.asnumpy())
+
+
 if __name__ == "__main__":
     test_compile()
+    test_run()
+    test_precompute_prune()

"""Unittest cases for graph pass"""
import nnvm
import nnvm.compiler
from nnvm.compiler import graph_pass
def test_infer_attr():
x = nnvm.symbol.Variable("x")
y = x * 2
g = nnvm.graph.create(y)
ishape, oshape = graph_pass.infer_shape(g, x=(10,20))
assert tuple(oshape[0]) == (10, 20)
itype, otype = graph_pass.infer_dtype(g, x="float32")
assert otype[0] == "float32"
if __name__ == "__main__":
test_infer_attr()