Commit 40bc10f3 by Tianqi Chen

[PASS] SimplifyBatchNorm->SimplifyInference, remove dropout (#24)

parent 215693df
...@@ -8,7 +8,7 @@ from .. import graph as _graph ...@@ -8,7 +8,7 @@ from .. import graph as _graph
from .. import runtime from .. import runtime
OPT_PASS_LEVEL = { OPT_PASS_LEVEL = {
"SimplifyBatchNormInference": 2, "SimplifyInference": 2,
"PrecomputePrune": 2, "PrecomputePrune": 2,
"OpFusion": 1 "OpFusion": 1
} }
...@@ -115,12 +115,9 @@ def optimize(graph, shape, dtype="float32"): ...@@ -115,12 +115,9 @@ def optimize(graph, shape, dtype="float32"):
""" """
# pylint: disable=unused-argument # pylint: disable=unused-argument
cfg = BuildConfig.current cfg = BuildConfig.current
graph = graph_attr.set_shape_inputs(graph, shape) if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyInference"]:
graph = graph.apply("InferShape") graph = graph_attr.set_shape_inputs(graph, shape)
if graph.json_attr("shape_num_unknown_nodes"): graph = graph.apply(["InferShape", "SimplifyInference"])
raise ValueError("InferShape fails..")
if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
graph = graph.apply("SimplifyBatchNormInference")
return graph return graph
......
...@@ -44,6 +44,11 @@ def _compute_binary(f): ...@@ -44,6 +44,11 @@ def _compute_binary(f):
_fschedule_broadcast = tvm.convert(_schedule_broadcast) _fschedule_broadcast = tvm.convert(_schedule_broadcast)
# copy
reg.register_compute("copy", _compute_unary(topi.identity))
reg.register_pattern("copy", OpPattern.ELEM_WISE)
reg.register_schedule("copy", _fschedule_broadcast)
# exp # exp
reg.register_compute("exp", _compute_unary(topi.exp)) reg.register_compute("exp", _compute_unary(topi.exp))
reg.register_pattern("exp", OpPattern.ELEM_WISE) reg.register_pattern("exp", OpPattern.ELEM_WISE)
......
...@@ -22,6 +22,7 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, ...@@ -22,6 +22,7 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
nnvm::NodeEntry moving_mean, nnvm::NodeEntry moving_mean,
nnvm::NodeEntry moving_var, nnvm::NodeEntry moving_var,
TShape dshape) { TShape dshape) {
CHECK_NE(dshape.ndim(), 0);
CHECK(attrs.op); CHECK(attrs.op);
static const Op* bn_op = Op::Get("batch_norm"); static const Op* bn_op = Op::Get("batch_norm");
CHECK(attrs.op == bn_op); CHECK(attrs.op == bn_op);
...@@ -76,13 +77,14 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs, ...@@ -76,13 +77,14 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
return {out, undef, undef}; return {out, undef, undef};
} }
Graph SimplifyBatchNormInference(nnvm::Graph src) { Graph SimplifyInference(nnvm::Graph src) {
// Get attributes from the graph // Get attributes from the graph
const IndexedGraph& idx = src.indexed_graph(); const IndexedGraph& idx = src.indexed_graph();
const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape"); const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
auto transform = [&](uint32_t nid, const Node* n, std::vector<NodeEntry>* ret) { auto transform = [&](uint32_t nid, const Node* n, std::vector<NodeEntry>* ret) {
if (n->is_variable()) return false; if (n->is_variable()) return false;
static const Op* bn_op = Op::Get("batch_norm"); static const Op* bn_op = Op::Get("batch_norm");
static const Op* dropout_op = Op::Get("dropout");
if (n->op() == bn_op) { if (n->op() == bn_op) {
*ret = BatchNormToInferUnpack( *ret = BatchNormToInferUnpack(
n->attrs, n->attrs,
...@@ -93,6 +95,10 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) { ...@@ -93,6 +95,10 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) {
n->inputs[4], n->inputs[4],
shape_vec[idx.entry_id(nid, 0)]); shape_vec[idx.entry_id(nid, 0)]);
return true; return true;
} else if (n->op() == dropout_op) {
NodeEntry undef = MakeNode("__undef__", "undef", {});
*ret = {n->inputs[0], undef};
return true;
} else { } else {
return false; return false;
} }
...@@ -100,8 +106,8 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) { ...@@ -100,8 +106,8 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) {
return GraphTransform(src, transform); return GraphTransform(src, transform);
} }
NNVM_REGISTER_PASS(SimplifyBatchNormInference) NNVM_REGISTER_PASS(SimplifyInference)
.set_body(SimplifyBatchNormInference); .set_body(SimplifyInference);
} // namespace compiler } // namespace compiler
} // namespace nnvm } // namespace nnvm
...@@ -30,12 +30,13 @@ def test_simplify_batchnorm(): ...@@ -30,12 +30,13 @@ def test_simplify_batchnorm():
for i in range(nstep): for i in range(nstep):
y1 = sym.batch_norm( y1 = sym.batch_norm(
y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis) y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
y1 = sym.dropout(y1)
y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var, y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var,
epsilon=eps, axis=axis, shape=ishape["x"]) epsilon=eps, axis=axis, shape=ishape["x"])
g = nnvm.graph.create(y1) g = nnvm.graph.create(y1)
g2 = nnvm.graph.create(y2) g2 = nnvm.graph.create(y2)
graph_attr.set_shape_inputs(g, ishape) graph_attr.set_shape_inputs(g, ishape)
g1 = g.apply("InferShape").apply("SimplifyBatchNormInference") g1 = g.apply("InferShape").apply("SimplifyInference")
# Some prints for debug # Some prints for debug
# print(g1.ir()) # print(g1.ir())
# assert graph equals as expected # assert graph equals as expected
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment