/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 *  Copyright (c) 2016 by Contributors
 * \file gradient.cc
 * \brief Pass that takes the gradient of a graph
 * This code was modified from the mxnet codebase by Min Lin
 */
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>
#include <algorithm>
#include <functional>
#include <sstream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace nnvm {
namespace pass {
namespace {

// default aggregate gradient function: returns a single gradient unchanged,
// a "zeros" node when there is no gradient, and an "elemwise_sum" node otherwise.
// requires the zeros and elemwise_sum operators to be registered.
NodeEntry DefaultAggregateGradient(std::vector<NodeEntry>&& v) {
  if (v.size() == 1) {
    return std::move(v[0]);
  } else if (v.size() == 0) {
    NodePtr zero_node = Node::Create();
    zero_node->attrs.op = Op::Get("zeros");
    zero_node->attrs.name = "zero_grad";
    zero_node->attrs.op->attr_parser(&(zero_node->attrs));
    return NodeEntry{zero_node, 0, 0};
  } else {
    NodePtr sum_node = Node::Create();
    sum_node->attrs.op = Op::Get("elemwise_sum");
    sum_node->inputs = std::move(v);
    sum_node->attrs.name = "grad_sum";
    sum_node->attrs.dict["num_args"] = std::to_string(sum_node->inputs.size());
    sum_node->attrs.op->attr_parser(&(sum_node->attrs));
    return NodeEntry{sum_node, 0, 0};
  }
}

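// check whether every entry in grads was produced by one of the registered zero ops,
// i.e. the aggregated output gradient is known to be all zeros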
bool CheckGradAllZero(const std::vector<NodeEntry>& grads,
                      const std::vector<const Op*>& zero_ops) {
  if (!grads.size() || !zero_ops.size()) return false;
  for (const auto& g : grads) {
    bool found = false;
    for (const auto& op : zero_ops) {
      if (g.node->op() == op) {
        found = true;
        break;
      }
    }
    if (!found) return false;
  }
  return true;
}

// helper entry that collects the gradients flowing into one output of a node
struct GradEntry {
#ifdef _MSC_VER
  NodeEntry sum = NodeEntry{nullptr, 0, 0};
#else
  NodeEntry sum{nullptr, 0, 0};
#endif
  std::vector<NodeEntry> grads;
  bool need_attr_hint{true};
};

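// pass body: visit the graph in reverse topological order, aggregate the
// gradients flowing into each node output, and apply each operator's FGradient
// function to propagate them to the node's inputs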
Graph Gradient(Graph src) {
  using nnvm::FGradient;
  using MirrorFun = std::function<int (const Node& node)>;
  using AttrHintFun = std::function<NodeEntry (const NodeEntry& src, const NodeEntry &like)>;

  CHECK_NE(src.attrs.count("grad_ys"), 0U)
      << "Gradient requires grad_ys to be present.";
  CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U)
      << "Gradient requires grad_ys_out_grad to be present.";
  CHECK_NE(src.attrs.count("grad_xs"), 0U)
      << "Gradient requires grad_xs to be present.";
  const std::vector<NodeEntry>& ys =
      src.GetAttr<std::vector<NodeEntry> >("grad_ys");
  const std::vector<NodeEntry>& ys_out_grad =
      src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
  const std::vector<NodeEntry>& xs =
      src.GetAttr<std::vector<NodeEntry> >("grad_xs");
  using AggFun = std::function<NodeEntry (std::vector<NodeEntry>&& inputs)>;
  AggFun agg_fun = DefaultAggregateGradient;
  if (src.attrs.count("grad_aggregate_fun") != 0) {
    agg_fun = src.GetAttr<AggFun>("grad_aggregate_fun");
  }
  MirrorFun mirror_fun = nullptr;
  if (src.attrs.count("grad_mirror_fun") != 0) {
    mirror_fun = src.GetAttr<MirrorFun>("grad_mirror_fun");
  }
  AttrHintFun attr_hint_fun = nullptr;
  if (src.attrs.count("attr_hint_fun") != 0) {
    attr_hint_fun = src.GetAttr<AttrHintFun>("attr_hint_fun");
  }
  std::vector<const Op*> zero_ops;
  if (src.attrs.count("zero_ops") != 0) {
    zero_ops = src.GetAttr<std::vector<const Op*> >("zero_ops");
  }
  const Op* copy_op = (src.attrs.count("copy_op") != 0) ?
      Op::Get(src.GetAttr<std::string>("copy_op")) :
      nullptr;

  // topo sort
  std::vector<NodePtr> topo_order;
  std::unordered_map<Node*, std::vector<GradEntry> > output_grads;

  DFSVisit(ys, [&](const NodePtr& node) {
      if (output_grads.count(node.get()) == 0) {
        output_grads[node.get()].resize(node->num_outputs());
      }
      topo_order.push_back(node);
    });

  CHECK_EQ(ys.size(), ys_out_grad.size());
  for (size_t i = 0; i < ys.size(); ++i) {
    NodeEntry ograd = ys_out_grad[i];
    output_grads[ys[i].node.get()][ys[i].index].grads = { ograd };
  }

  // Check that all xs are reachable from ys
  for (size_t i = 0; i < xs.size(); ++i) {
    CHECK(output_grads.find(xs[i].node.get()) != output_grads.end())
        << "Cannot differentiate with respect to the " << i+1 << "-th variable "
        << "because it is unreachable from the outputs.";
  }

  // construct mirror nodes as a memory-reduction strategy if needed
  std::unordered_map<Node*, NodePtr> mirror_map;
  if (mirror_fun != nullptr) {
    for (const NodePtr& node_ptr : topo_order) {
      if (mirror_fun(*node_ptr)) {
        NodePtr new_node = Node::Create();
        *new_node = *node_ptr;
        new_node->attrs.name += "_mirror";
        for (auto& e : new_node->inputs) {
          e.node = mirror_map.at(e.node.get());
        }
        for (auto& n : new_node->control_deps) {
          n = mirror_map.at(n.get());
        }
        mirror_map[node_ptr.get()] = std::move(new_node);
      } else {
        mirror_map[node_ptr.get()] = node_ptr;
      }
    }
  }

  // traverse backward
  static auto& grad_fun_map = Op::GetAttr<FGradient>("FGradient");
  static auto& finfer_shape = Op::GetAttr<FInferShape>("FInferShape");

  std::vector<NodeEntry> out_agg_grads;
  for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) {
    const NodePtr& ptr = *rit;
    if (ptr->is_variable()) continue;
    out_agg_grads.clear();
    auto& out_grad_vec = output_grads.at(ptr.get());
    for (uint32_t i = 0; i < out_grad_vec.size(); ++i) {
      GradEntry& e = out_grad_vec[i];
      e.sum = agg_fun(std::move(e.grads));
      if (e.need_attr_hint && attr_hint_fun != nullptr) {
        e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i});
      }
      out_agg_grads.push_back(e.sum);
    }
    if ((*rit)->inputs.size() != 0) {
      NodePtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()));
      std::vector<NodeEntry> input_grads;
      // Check for FGradient
      if (grad_fun_map.contains(ptr->op())) {
        input_grads = grad_fun_map[ptr->op()](fwd_node, out_agg_grads);
        CHECK_EQ((*rit)->inputs.size(), input_grads.size())
            << "Gradient function not returning enough gradient";
      } else if (CheckGradAllZero(out_agg_grads, zero_ops)) {
        for (size_t i = 0; i < fwd_node->num_inputs(); ++i) {
          std::ostringstream os;
          if (1 == fwd_node->num_inputs()) {
            os << fwd_node->attrs.name << "_backward";
          } else {
            os << fwd_node->attrs.name << "_in" << i << "_backward";
          }
          auto p = Node::Create();
          p->attrs.op = zero_ops[0];
          p->attrs.name = os.str();
          p->inputs.push_back(fwd_node->inputs[i]);
          p->control_deps.emplace_back(fwd_node);
          if (p->op()->attr_parser != nullptr) {
            p->op()->attr_parser(&(p->attrs));
          }
          input_grads.emplace_back(p, 0, 0);
        }
      } else {
        LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable "
                   << "because it didn't register FGradient attribute.";
      }
      for (const auto& nodeEntry : input_grads)
        CHECK(nodeEntry.node);
      auto git = input_grads.begin();
      CHECK((*rit)->inputs.size() <= input_grads.size());
      for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
        auto& output_grad_entry = output_grads[it->node.get()][it->index];
        // if any of the backward ops can do shape inference, the hint is not necessary.
        if (finfer_shape.contains(git->node->op())) {
          output_grad_entry.need_attr_hint = false;
        }
        output_grad_entry.grads.emplace_back(std::move(*git));
      }
    }
  }
  // take out the xs' grads
  Graph ret;
  ret.outputs.resize(xs.size());
  NodeEntryMap<std::pair<size_t, size_t> > unique_grads;
  size_t counter = 0;
  for (const NodeEntry& e : xs) {
    GradEntry& entry = output_grads[e.node.get()][e.index];
    // aggregate the gradients if they have not been aggregated yet
    if (entry.sum.node.get() == nullptr) {
      entry.sum = agg_fun(std::move(entry.grads));
      if (entry.need_attr_hint && attr_hint_fun != nullptr) {
        entry.sum = attr_hint_fun(entry.sum, e);
      }
    }
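    // if a copy operator is given, deduplicate gradient outputs: the first
    // occurrence is used directly and later duplicates go through a copy node,
    // so every returned output is a distinct NodeEntry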
    if (copy_op != nullptr) {
      auto kv = unique_grads.find(entry.sum);
      if (kv == unique_grads.end()) {
        unique_grads.emplace(std::move(entry.sum), std::make_pair(1, counter));
      } else {
        NodePtr copy_node = Node::Create();
        std::ostringstream os;
        os << entry.sum.node->attrs.name << "_" << kv->second.first << "_copy";
        kv->second.first++;
        copy_node->attrs.op = copy_op;
        copy_node->attrs.name = os.str();
        copy_node->inputs.emplace_back(entry.sum);
        if (copy_node->attrs.op->attr_parser != nullptr) {
          copy_node->attrs.op->attr_parser(&(copy_node->attrs));
        }
        unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, std::make_pair(1, counter));
      }
    } else {
      ret.outputs[counter] = entry.sum;
    }
    ++counter;
  }
  if (copy_op != nullptr) {
    for (const auto& kv : unique_grads) {
      ret.outputs[kv.second.second] = kv.first;
    }
  }
  return ret;
}

// register pass
NNVM_REGISTER_PASS(Gradient)
.describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]")
.set_body(Gradient)
.set_change_graph(true)
.depend_graph_attr("grad_ys")
.depend_graph_attr("grad_xs")
.depend_graph_attr("grad_ys_out_grad");

}  // namespace
}  // namespace pass
}  // namespace nnvm