Commit 40bc10f3 by Tianqi Chen

[PASS] SimplifyBatchNorm->SimplifyInference, remove dropout (#24)

parent 215693df
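
What SimplifyInference exploits: at inference time batch_norm's statistics are frozen, so the whole op folds into a per-channel scale and shift, and dropout becomes an identity. A minimal NumPy sketch of the batch-norm algebra this pass relies on (names here are illustrative, not from the patch):

import numpy as np

def batch_norm_inference(x, gamma, beta, moving_mean, moving_var, eps=1e-5):
    # out = (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta,
    # folded into a single affine transform of x.
    scale = gamma / np.sqrt(moving_var + eps)
    shift = beta - moving_mean * scale
    return x * scale + shift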
@@ -8,7 +8,7 @@ from .. import graph as _graph
 from .. import runtime

 OPT_PASS_LEVEL = {
-    "SimplifyBatchNormInference": 2,
+    "SimplifyInference": 2,
     "PrecomputePrune": 2,
     "OpFusion": 1
 }
@@ -115,12 +115,9 @@ def optimize(graph, shape, dtype="float32"):
     """
     # pylint: disable=unused-argument
    cfg = BuildConfig.current
-    graph = graph_attr.set_shape_inputs(graph, shape)
-    graph = graph.apply("InferShape")
-    if graph.json_attr("shape_num_unknown_nodes"):
-        raise ValueError("InferShape fails..")
-    if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyBatchNormInference"]:
-        graph = graph.apply("SimplifyBatchNormInference")
+    if cfg.opt_level >= OPT_PASS_LEVEL["SimplifyInference"]:
+        graph = graph_attr.set_shape_inputs(graph, shape)
+        graph = graph.apply(["InferShape", "SimplifyInference"])
     return graph
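Note the ordering above: SimplifyInference consumes the per-entry "shape" attribute, so InferShape must run first (this is also why the CHECK_NE(dshape.ndim(), 0) guard appears below). A sketch of driving the passes by hand, assuming the same Python API the test at the bottom of this commit uses:

import nnvm
import nnvm.symbol as sym
from nnvm.compiler import graph_attr

x = sym.Variable("x")
g = nnvm.graph.create(x + 1)
graph_attr.set_shape_inputs(g, {"x": (1, 16)})
# InferShape fills in the "shape" attribute that SimplifyInference reads.
g = g.apply("InferShape").apply("SimplifyInference")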
@@ -44,6 +44,11 @@ def _compute_binary(f):
 _fschedule_broadcast = tvm.convert(_schedule_broadcast)

+# copy
+reg.register_compute("copy", _compute_unary(topi.identity))
+reg.register_pattern("copy", OpPattern.ELEM_WISE)
+reg.register_schedule("copy", _fschedule_broadcast)
+
 # exp
 reg.register_compute("exp", _compute_unary(topi.exp))
 reg.register_pattern("exp", OpPattern.ELEM_WISE)
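The "copy" registration reuses _compute_unary, a helper defined earlier in this file but outside the hunk. It is presumably a small closure adapting a one-input TOPI function to the compute-registration signature; a sketch, with the exact signature being an assumption:

def _compute_unary(f):
    # Adapt a one-input TOPI function (e.g. topi.identity, topi.exp)
    # to the (attrs, inputs, ...) signature expected by register_compute.
    def _compute(attrs, inputs, out_info):
        return f(inputs[0])
    return _compute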
@@ -22,6 +22,7 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
                        nnvm::NodeEntry moving_mean,
                        nnvm::NodeEntry moving_var,
                        TShape dshape) {
+  CHECK_NE(dshape.ndim(), 0);
   CHECK(attrs.op);
   static const Op* bn_op = Op::Get("batch_norm");
   CHECK(attrs.op == bn_op);
@@ -76,13 +77,14 @@ BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
   return {out, undef, undef};
 }

-Graph SimplifyBatchNormInference(nnvm::Graph src) {
+Graph SimplifyInference(nnvm::Graph src) {
   // Get attributes from the graph
   const IndexedGraph& idx = src.indexed_graph();
   const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");
   auto transform = [&](uint32_t nid, const Node* n, std::vector<NodeEntry>* ret) {
     if (n->is_variable()) return false;
     static const Op* bn_op = Op::Get("batch_norm");
+    static const Op* dropout_op = Op::Get("dropout");
     if (n->op() == bn_op) {
       *ret = BatchNormToInferUnpack(
           n->attrs,
@@ -93,6 +95,10 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) {
           n->inputs[4],
           shape_vec[idx.entry_id(nid, 0)]);
       return true;
+    } else if (n->op() == dropout_op) {
+      NodeEntry undef = MakeNode("__undef__", "undef", {});
+      *ret = {n->inputs[0], undef};
+      return true;
     } else {
       return false;
     }
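
The dropout branch is valid because dropout only masks and rescales activations during training; at inference it passes data through unchanged. dropout has two outputs, (data, mask): the data output is rewired to the op's input and the dead mask output is stubbed with __undef__. A NumPy sketch of the semantics being assumed (illustrative only):

import numpy as np

def dropout(x, rate=0.5, training=False):
    if not training:
        return x, None  # inference: identity, mask unused
    mask = np.random.rand(*x.shape) >= rate
    return x * mask / (1.0 - rate), mask  # inverted-dropout scaling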
@@ -100,8 +106,8 @@ Graph SimplifyBatchNormInference(nnvm::Graph src) {
   return GraphTransform(src, transform);
 }

-NNVM_REGISTER_PASS(SimplifyBatchNormInference)
-.set_body(SimplifyBatchNormInference);
+NNVM_REGISTER_PASS(SimplifyInference)
+.set_body(SimplifyInference);

 }  // namespace compiler
 }  // namespace nnvm
@@ -30,12 +30,13 @@ def test_simplify_batchnorm():
         for i in range(nstep):
             y1 = sym.batch_norm(
                 y1 + 1, gamma, beta, moving_mean, moving_var, epsilon=eps, axis=axis)
+            y1 = sym.dropout(y1)
             y2 = simple_bn(y2 + 1, gamma, beta, moving_mean, moving_var,
                            epsilon=eps, axis=axis, shape=ishape["x"])
         g = nnvm.graph.create(y1)
         g2 = nnvm.graph.create(y2)
         graph_attr.set_shape_inputs(g, ishape)
-        g1 = g.apply("InferShape").apply("SimplifyBatchNormInference")
+        g1 = g.apply("InferShape").apply("SimplifyInference")
         # Some prints for debug
         # print(g1.ir())
         # assert graph equals as expected
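For context, simple_bn is defined earlier in this test file and builds the expected unpacked graph by hand, mirroring what BatchNormToInferUnpack emits. Roughly along these lines, hedged since the helper's body is outside the hunk:

def simple_bn(x, gamma, beta, moving_mean, moving_var,
              axis=1, epsilon=1e-5, shape=None):
    # Expected expansion: (x - moving_mean) / sqrt(moving_var + eps) * gamma + beta
    scale = sym.elemwise_mul(1 / sym.sqrt(moving_var + epsilon), gamma)
    shift = sym.elemwise_add(
        sym.elemwise_mul(sym.negative(moving_mean), scale), beta)
    # Expand trailing axes so the per-channel parameters broadcast along `axis`.
    num_newaxis = len(shape) - axis - 1
    if num_newaxis:
        scale = sym.expand_dims(scale, axis=1, num_newaxis=num_newaxis)
        shift = sym.expand_dims(shift, axis=1, num_newaxis=num_newaxis)
    return x * scale + shift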