/*!
 * Copyright (c) 2017 by Contributors
 * \file simplify_inference.cc
 * \author Ziheng Jiang
 * \brief Graph pass that simplifies a graph for inference: batch_norm is
 *  rewritten into explicit scale/shift arithmetic and dropout is bypassed.
 */
#include <nnvm/graph.h>
#include <nnvm/op_attr_types.h>
#include <nnvm/graph_attr_types.h>
#include <nnvm/pass.h>
#include <nnvm/compiler/op_attr_types.h>
#include <nnvm/top/nn.h>

#include <iomanip>
#include <limits>
#include <sstream>
#include <string>
#include <vector>
#include "graph_transform.h"
#include "pattern_util.h"

namespace nnvm {
namespace compiler {

/*!
 * \brief Rewrite a batch_norm node into inference-time arithmetic.
 *
 *  The normalization is folded into an affine transform of the data:
 *    scale = 1 / sqrt(moving_var + epsilon)      (times gamma if param.scale)
 *    shift = -moving_mean * scale                (plus beta if param.center)
 *    out   = broadcast_add(broadcast_mul(data, scale), shift)
 *
 * \param attrs Attributes of the node being rewritten; attrs.op must be
 *  batch_norm and attrs.parsed must hold a top::BatchNormParam.
 * \param data The data input entry.
 * \param gamma Gamma input entry; only used when param.scale is true.
 * \param beta Beta input entry; only used when param.center is true.
 * \param moving_mean Moving-mean input entry.
 * \param moving_var Moving-variance input entry.
 * \param dshape Inferred shape of the node's output 0; must have ndim != 0.
 * \param bshape Inferred shape of the node's output 1.
 * \return Three entries replacing the node's outputs: the transformed data
 *  followed by two "__undef__" placeholders.
 */
std::vector<NodeEntry>
BatchNormToInferUnpack(const nnvm::NodeAttrs& attrs,
                       nnvm::NodeEntry data,
                       nnvm::NodeEntry gamma,
                       nnvm::NodeEntry beta,
                       nnvm::NodeEntry moving_mean,
                       nnvm::NodeEntry moving_var,
                       TShape dshape,
                       TShape bshape) {
  CHECK_NE(dshape.ndim(), 0);
  CHECK(attrs.op);
  static const Op* bn_op = Op::Get("batch_norm");
  CHECK(attrs.op == bn_op);
  const auto& param = nnvm::get<top::BatchNormParam>(attrs.parsed);
  const std::string& bn_name = attrs.name;

  // std::to_string(double) formats with "%f" (six fixed decimals), which
  // rounds small epsilons to zero (1e-7 -> "0.000000") and truncates others.
  // Emit enough significant digits for the value to round-trip exactly.
  std::ostringstream epsilon_str;
  epsilon_str << std::setprecision(std::numeric_limits<double>::max_digits10)
              << param.epsilon;

  // scale = 1 / sqrt(moving_var + epsilon), optionally multiplied by gamma.
  NodeEntry var_add_eps = MakeNode(
      "__add_scalar__", bn_name + "_add_eps",
      {moving_var}, {{"scalar", epsilon_str.str()}});

  NodeEntry sqrt = MakeNode(
      "sqrt", bn_name + "_sqrt", {var_add_eps});

  NodeEntry scale = MakeNode(
      "__rdiv_scalar__", bn_name + "_div",
      {sqrt}, {{"scalar", "1"}});

  if (param.scale) {
    scale = MakeNode(
        "elemwise_mul", bn_name + "_gamma_mul_div",
        {scale, gamma});
  }

  // shift = -moving_mean * scale, optionally plus beta.
  NodeEntry neg_mean = MakeNode(
      "negative", bn_name + "_neg_mean", {moving_mean});

  NodeEntry shift = MakeNode(
      "elemwise_mul", bn_name + "_neg_mean_mul_a",
      {neg_mean, scale});

  if (param.center) {
    shift = MakeNode(
        "elemwise_add", bn_name + "_add_beta", {shift, beta});
  }

  // Expand scale/shift so they broadcast against data along param.axis
  // (see ExpandBiasToMatchAxis in pattern_util.h). Computed once, in signed
  // arithmetic, for both operands.
  const int axis = param.axis;
  const int out_ndim =
      static_cast<int>(dshape.ndim()) - static_cast<int>(bshape.ndim()) + 1;
  scale = ExpandBiasToMatchAxis(scale, out_ndim, 1, axis);
  shift = ExpandBiasToMatchAxis(shift, out_ndim, 1, axis);

  NodeEntry out = MakeNode("broadcast_mul", bn_name + "_a_mul_data",
                           {data, scale});
  out = MakeNode("broadcast_add", bn_name + "_out",
                 {out, shift});

  // After this rewrite the node's other outputs have no defined value;
  // consumers must not reference them.
  NodeEntry undef = MakeNode("__undef__", "undef", {});
  return {out, undef, undef};
}

/*!
 * \brief Simplify a graph for inference.
 *
 *  - batch_norm nodes are replaced by the arithmetic built by
 *    BatchNormToInferUnpack, using the inferred shapes of the node's
 *    outputs 0 and 1.
 *  - dropout nodes are bypassed: output 0 forwards the node's first input
 *    and output 1 becomes an "__undef__" placeholder.
 *  All other nodes are left untouched.
 *
 * \param src The source graph; must carry the "shape" graph attribute.
 * \return The transformed graph produced by GraphTransform.
 */
Graph SimplifyInference(nnvm::Graph src) {
  // Shapes are indexed by the indexed graph's entry ids.
  const IndexedGraph& idx = src.indexed_graph();
  const ShapeVector& shape_vec = src.GetAttr<ShapeVector>("shape");

  // Returns true (and fills *ret with replacement outputs) when node n
  // should be rewritten; false keeps the node as-is.
  auto transform = [&](uint32_t nid, const NodePtr& n, std::vector<NodeEntry>* ret) {
    if (n->is_variable()) return false;
    static const Op* bn_op = Op::Get("batch_norm");
    static const Op* dropout_op = Op::Get("dropout");

    if (n->op() == bn_op) {
      // Inputs in BatchNormToInferUnpack's order:
      // data, gamma, beta, moving_mean, moving_var.
      *ret = BatchNormToInferUnpack(
          n->attrs,
          n->inputs[0],
          n->inputs[1],
          n->inputs[2],
          n->inputs[3],
          n->inputs[4],
          shape_vec[idx.entry_id(nid, 0)],
          shape_vec[idx.entry_id(nid, 1)]);
      return true;
    }

    if (n->op() == dropout_op) {
      // Inference-time dropout is the identity on its input.
      NodeEntry undef = MakeNode("__undef__", "undef", {});
      *ret = {n->inputs[0], undef};
      return true;
    }

    return false;
  };
  return GraphTransform(src, transform);
}

// The pass rewrites graph structure and reads the "shape" graph attribute,
// so that dependency is declared for ApplyPasses to check up front.
NNVM_REGISTER_PASS(SimplifyInference)
.describe("Simplify the graph for inference: rewrite batch_norm into "
          "scale/shift arithmetic and bypass dropout.")
.set_body(SimplifyInference)
.set_change_graph(true)
.depend_graph_attr("shape");

}  // namespace compiler
}  // namespace nnvm