symbolic.cc 21.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 * 
 *   http://www.apache.org/licenses/LICENSE-2.0
 * 
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21 22 23 24 25 26 27 28 29 30 31
/*!
 *  Copyright (c) 2016 by Contributors
 * \file symbolic.cc
 * \brief Symbolic graph composition API.
 */
#include <nnvm/graph.h>
#include <nnvm/symbolic.h>
#include <nnvm/op_attr_types.h>

namespace nnvm {

namespace symbol_constants {
// Separator inserted between a node name and an attribute key when
// Symbol::ListAttrs flattens attributes recursively into one map.
const char* kNamespaceSeparator = "$";
}  // namespace symbol_constants

35 36 37 38 39
// Auxiliary version attribute stored on variable nodes.
// UpdateNodeVersion bumps it whenever an operator mutates the variable.
struct VariableParam {
  uint32_t version = 0;
};

40 41
// Create a variable node called `name`: a node without an operator whose
// parsed attribute holds a default-constructed VariableParam.
NodePtr CreateVariableNode(const std::string& name) {
  NodePtr var = Node::Create();
  auto& attrs = var->attrs;
  attrs.op = nullptr;
  attrs.name = name;
  attrs.parsed = VariableParam();
  return var;
}

// Refresh the version stored on every input entry of `n`.
// Each entry that refers to a variable first takes that variable's current
// version. If the op of `n` registers FMutateInputs, every input it lists
// must be a variable; that variable's version is incremented and the new
// value is recorded on the entry. The version is used to implicitly order
// the mutation sequences.
inline void UpdateNodeVersion(Node *n) {
  static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
  // Step 1: bring variable inputs up to the latest version.
  for (NodeEntry& entry : n->inputs) {
    if (!entry.node->is_variable()) continue;
    entry.version = nnvm::get<VariableParam>(entry.node->attrs.parsed).version;
  }
  // Step 2: only ops that mutate inputs need further work.
  if (fmutate_inputs.count(n->op()) == 0) return;
  for (uint32_t idx : fmutate_inputs[n->op()](n->attrs)) {
    NodeEntry& target = n->inputs[idx];
    CHECK(target.node->is_variable())
        << "Mutation target can only be Variable";
    // increase the version of the variable.
    VariableParam& param = nnvm::get<VariableParam>(target.node->attrs.parsed);
    target.version = ++param.version;
  }
}
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108

// Name given to an automatically created input variable:
// `arg_name` alone when `op_name` is empty, otherwise "<op_name>_<arg_name>".
inline std::string DefaultVarName(const std::string &op_name,
                                  const std::string &arg_name) {
  if (op_name.empty()) return arg_name;
  std::string full(op_name);
  full += '_';
  full += arg_name;
  return full;
}

// Report the first user-supplied keyword that is not a candidate argument.
//
// \param source    label of the calling API, used as the message prefix.
// \param user_args keyword names supplied by the caller.
// \param args      candidate argument names that keywords may match.
//
// Raises LOG(FATAL) for the first entry of `user_args` that is absent from
// `args`; the message lists every candidate with its position. Returns
// normally when every keyword is a candidate.
inline void KeywordArgumentMismatch(const char *source,
                                    const std::vector<std::string>& user_args,
                                    const array_view<std::string>& args) {
  const std::unordered_set<std::string> known(args.begin(), args.end());
  for (const auto& key : user_args) {
    if (known.count(key) != 0) continue;
    // Only build the candidate listing once a mismatch is actually found.
    std::ostringstream candidates;
    candidates << "\nCandidate arguments:\n";
    for (size_t i = 0; i < args.size(); ++i) {
      candidates << "\t[" << i << ']' << args[i] << '\n';
    }
    // Separate the source label from the message text; previously the two
    // were fused together (e.g. "Symbol.ComposeKeyword argument name ...").
    LOG(FATAL) << source << ": "
               << "Keyword argument name " << key << " not found."
               << candidates.str();
  }
}

// Collect the keys of `kwargs`, in the map's iteration order.
template<typename T>
inline std::vector<std::string> GetKeys(
    const std::unordered_map<std::string, T>& kwargs) {
  std::vector<std::string> keys;
  keys.reserve(kwargs.size());
  for (const auto& kv : kwargs) {
    keys.push_back(kv.first);
  }
  return keys;
}

// whether the symbol is atomic functor
inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
109 110 111 112 113
  Node* node = outputs[0].node.get();
  for (const NodeEntry& e : outputs) {
    if (node != e.node.get()) return false;
  }
  return node->inputs.size() == 0 && node->control_deps.size() == 0;
114 115 116 117
}

// public functions
// public functions

// Deep-copy the graph reachable from this symbol's outputs.
// Every visited node is cloned (attributes copied), then inputs and control
// dependencies of each clone are rewired to the corresponding clones, and
// finally a new Symbol whose outputs point at the cloned heads is returned.
Symbol Symbol::Copy() const {
  std::unordered_map<Node*, NodePtr> old_new;
  // Pass 1: use DFSVisit to clone the attributes of all the nodes.
  DFSVisit(this->outputs, [&old_new](const NodePtr& node) {
      NodePtr np = Node::Create();
      np->attrs = node->attrs;
      old_new[node.get()] = std::move(np);
    });
  // Non-inserting lookup. Using operator[] here would insert a null entry for
  // an unknown key, and an insertion while range-iterating old_new below can
  // rehash the table and invalidate the iteration.
  const auto clone_of = [&old_new](const NodePtr& src) -> const NodePtr& {
    const auto it = old_new.find(src.get());
    CHECK(it != old_new.end())
        << "Symbol.Copy: referenced node was not visited by DFSVisit";
    return it->second;
  };
  // Pass 2: connect nodes of new graph.
  for (const auto &kv : old_new) {
    for (const NodeEntry& e : kv.first->inputs) {
      kv.second->inputs.emplace_back(NodeEntry{clone_of(e.node), e.index, e.version});
    }
    for (const NodePtr& p : kv.first->control_deps) {
      kv.second->control_deps.emplace_back(clone_of(p));
    }
  }
  // set the head
  Symbol ret;
  for (const NodeEntry &e : outputs) {
    ret.outputs.emplace_back(NodeEntry{clone_of(e.node), e.index, e.version});
  }
  return ret;
}

void Symbol::Print(std::ostream &os) const {
144 145 146
  if (outputs.size() == 1 &&
      outputs[0].node->inputs.size() == 0 &&
      outputs[0].node->control_deps.size() == 0) {
147 148 149
    if (outputs[0].node->is_variable()) {
      os << "Variable:" << outputs[0].node->attrs.name << '\n';
    } else {
150
      os << "AtomicFunctor "<< " Op:" << outputs[0].node->op()->name << '\n';
151
    }
152 153
  } else {
    // use DFSVisit to copy all the nodes
154
    os << "Symbol Outputs:\n";
155 156 157 158
    for (size_t i = 0; i < outputs.size(); ++i) {
      os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name
         << '(' << outputs[i].index << ")\n";
    }
159
    DFSVisit(this->outputs, [&os](const NodePtr& node) {
160 161 162
        if (node->is_variable()) {
          os << "Variable:" << node->attrs.name << '\n';
        } else {
163
          os << "--------------------\n";
164
          os << "Op:" << node->op()->name << ", Name=" << node->attrs.name << '\n'
165 166
             << "Inputs:\n";
          for (size_t i = 0; i < node->inputs.size(); ++i) {
167 168 169 170 171 172 173 174
            const NodeEntry& e = node->inputs[i];
            os << "\targ[" << i << "]=" << e.node->attrs.name
               << '(' << e.index << ")";
            if (e.node->is_variable()) {
              os << " version=" << e.version << '\n';
            } else {
              os << '\n';
            }
175
          }
176 177
          if (!node->attrs.dict.empty()) {
            os << "Attrs:\n";
Tianqi Chen committed
178 179 180 181
            // make an ordered copy because unordered_map doesn't guarantee order.
            std::map<std::string, std::string> sorted_dict(
              node->attrs.dict.begin(), node->attrs.dict.end());
            for (auto &kv : sorted_dict) {
182 183 184 185 186 187 188 189
              os << '\t' << kv.first << '=' << kv.second << '\n';
            }
          }
          if (node->control_deps.size() != 0) {
            os << "Control deps:\n";
            for (size_t i = 0; i < node->control_deps.size(); ++i) {
              os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n';
            }
190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207
          }
        }
      });
  }
}

// Select one output of the symbol.
// `index` must be smaller than the number of outputs. A symbol with a single
// output is returned unchanged; otherwise a new symbol holding only
// outputs[index] is returned.
Symbol Symbol::operator[] (size_t index) const {
  const size_t nreturn = outputs.size();
  CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index";
  if (nreturn == 1) return *this;
  Symbol picked;
  picked.outputs.push_back(outputs[index]);
  return picked;
}

208 209
// List the variable nodes reachable from the outputs, in DFSVisit order.
//  - kAll:             every variable.
//  - kReadOnlyArgs:    variables that no visited op lists in FMutateInputs.
//  - kAuxiliaryStates: variables that some visited op lists in FMutateInputs.
std::vector<NodePtr> Symbol::ListInputs(ListInputOption option) const {
  std::vector<NodePtr> ret;
  if (option == kAll) {
    ret.reserve(this->outputs.size());
    DFSVisit(this->outputs, [&ret](const NodePtr &node) {
        if (node->is_variable()) ret.push_back(node);
      });
    return ret;
  }

  std::unordered_set<Node*> mutable_set;
  std::vector<NodePtr> vlist;
  vlist.reserve(this->outputs.size());
  static auto& fmutate_inputs = Op::GetAttr<FMutateInputs>("FMutateInputs");
  // One traversal gathers all variables and the set of mutated inputs.
  DFSVisit(this->outputs, [&mutable_set, &vlist](const NodePtr &node) {
      if (node->is_variable()) {
        vlist.push_back(node);
        return;
      }
      if (fmutate_inputs.count(node->op())) {
        for (uint32_t idx : fmutate_inputs[node->op()](node->attrs)) {
          mutable_set.insert(node->inputs[idx].node.get());
        }
      }
    });

  const bool want_readonly = (option == kReadOnlyArgs);
  const bool want_mutated = (option == kAuxiliaryStates);
  ret.reserve(vlist.size());
  for (const NodePtr& node : vlist) {
    const bool mutated = mutable_set.count(node.get()) != 0;
    if ((want_readonly && !mutated) || (want_mutated && mutated)) {
      ret.emplace_back(node);
    }
  }
  return ret;
}

242 243 244 245 246 247 248 249 250
// Names of the nodes returned by ListInputs(option), in the same order.
std::vector<std::string> Symbol::ListInputNames(ListInputOption option) const {
  const std::vector<NodePtr> inputs = ListInputs(option);
  std::vector<std::string> names;
  names.reserve(inputs.size());
  for (const NodePtr& node : inputs) {
    names.push_back(node->attrs.name);
  }
  return names;
}

251
std::vector<std::string> Symbol::ListOutputNames() const {
252
  static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
253

254
  std::vector<std::string> ret;
255
  ret.reserve(outputs.size());
256 257 258 259 260 261
  for (auto &head : outputs) {
    if (head.node->is_variable()) {
      ret.push_back(head.node->attrs.name);
    } else {
      const std::string& hname = head.node->attrs.name;
      std::string rname;
262
      FListOutputNames fn = flist_ouputs.get(head.node->op(), nullptr);
263
      if (fn != nullptr) {
264
        rname = fn(head.node->attrs)[head.index];
265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283
      } else {
        rname = "output";
        if (head.node->num_outputs() != 1) {
          std::ostringstream os;
          os << rname << head.index;
          rname = os.str();
        }
      }
      if (hname.length() == 0) {
        ret.push_back(std::move(rname));
      } else {
        ret.push_back(hname + '_' + rname);
      }
    }
  }
  return ret;
}

// compositional logic
284 285
void Symbol::Compose(const array_view<const Symbol*>& args,
                     const std::unordered_map<std::string, const Symbol*>& kwargs,
286
                     const std::string& name) {
287
  static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
288
  static auto& fset_attrs = Op::GetAttr<FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose");
289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306
  static auto& fgraph = Op::GetAttr<FInputGraph>("FInputGraph");

  // The arguments that contain graphs.
  Node* n = outputs[0].node.get();
  FInputGraph fng = fgraph.get(n->op(), nullptr);
  std::vector<uint32_t> garg_idx;
  if (fng != nullptr)
    garg_idx = fng(n->attrs);

  // The names of the arguments that contain graphs.
  FListInputNames name_fn = flist_inputs.get(n->op(), nullptr);
  auto arg_names = (name_fn == nullptr) ? std::vector<std::string>{"data"} : name_fn(n->attrs);
  std::vector<std::string> garg_names(garg_idx.size());
  for (size_t i = 0; i < garg_idx.size(); i++) {
    size_t idx = garg_idx[i];
    if (idx < arg_names.size())
      garg_names[i] = arg_names[idx];
  }
307

308 309
  // parameter check.
  for (size_t i = 0; i < args.size(); ++i) {
310 311 312
    // If the argument isn't a graph, it should have only one output.
    if (garg_idx.empty() || std::find(garg_idx.begin(), garg_idx.end(), i) == garg_idx.end())
      CHECK_EQ(args[i]->outputs.size(), 1U)
313 314 315
        << "Argument " << i << " is a tuple, single value is required";
  }
  for (const auto& kv : kwargs) {
316 317 318
    if (garg_names.empty()
        || std::find(garg_names.begin(), garg_names.end(), kv.first) == garg_names.end())
      CHECK_EQ(kv.second->outputs.size(), 1U)
319 320
        << "Keyword Argument " << kv.first << " is a tuple, single value is required";
  }
Eric Junyuan Xie committed
321 322
  // assign new name
  if (!name.empty()) outputs[0].node->attrs.name = name;
323 324 325 326

  // Atomic functor composition.
  if (IsAtomic(outputs)) {
    uint32_t n_req = n->num_inputs();
327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346
    std::vector<const Symbol *> arg_vec(args.begin(), args.end());
    std::unordered_map<std::string, const Symbol*> kwarg_map(kwargs.begin(), kwargs.end());
    // If one of the input arguments is a graph, we need to remove it from the
    // list.
    if (fng != nullptr) {
      std::vector<uint32_t> idxes = fng(n->attrs);
      for (auto idx : idxes) {
        const Symbol *sym;
        if (idx < arg_vec.size()) {
          sym = arg_vec[idx];
        } else {
          auto it = kwarg_map.find(arg_names[idx]);
          CHECK(it != kwarg_map.end());
          sym = it->second;
          kwarg_map.erase(it);
        }
        if (n_req != kVarg)
          n_req--;
        n->attrs.subgraphs.push_back(std::make_shared<Symbol>(*sym));
      }
347 348 349 350 351 352 353 354 355 356
      // Because idxes does not contain duplicates, the loop below functions well.
      // Note that it is as slow as O(|idxes| * |args|),
      // but given that |idxes| is small, it is just fine
      sort(std::begin(idxes), std::end(idxes), std::greater<int>());
      for (auto idx : idxes) {
        if (idx < arg_vec.size()) {
          arg_vec.erase(arg_vec.begin() + idx);
        }
        arg_names.erase(arg_names.begin() + idx);
      }
357
    }
358 359 360

    if (n_req != kVarg) {
      n->inputs.resize(n_req);
361
      CHECK_LE(arg_vec.size(), n_req)
362
          << "Incorrect number of arguments, requires " << n_req
363 364 365
          << ", provided " << arg_vec.size();
      for (size_t i = 0; i < arg_vec.size(); ++i) {
        n->inputs[i] = arg_vec[i]->outputs[0];
366 367
      }
      // switch to keyword argument matching
368
      if (arg_vec.size() != n_req) {
369
        if (arg_names.size() != n_req) {
370
          LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op()->name;
371
        }
372
        size_t nmatched = 0;
373 374 375
        for (size_t i = arg_vec.size(); i < n_req; ++i) {
          auto it = kwarg_map.find(arg_names[i]);
          if (it != kwarg_map.end() && it->first == arg_names[i]) {
376
            n->inputs[i] = it->second->outputs[0];
377 378
            ++nmatched;
          } else {
379 380
            n->inputs[i] = NodeEntry{
              CreateVariableNode(DefaultVarName(name, arg_names[i])), 0, 0};
381 382
            // copy attribute of parent over automatically created variables
            n->inputs[i].node->attrs.dict = n->attrs.dict;
383 384 385
          }
        }

386
        if (nmatched != kwarg_map.size()) {
387
          n->inputs.clear();
388 389
          std::vector<std::string> keys = GetKeys(kwarg_map);
          array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_vec.size(),
390 391 392 393 394
                                       dmlc::BeginPtr(arg_names) + arg_names.size());
          KeywordArgumentMismatch("Symbol.Compose", keys, view);
        }
      }
    } else {
395 396 397
      CHECK_EQ(kwarg_map.size(), 0U) << "Variable length function do not accept kwargs";
      n->inputs.reserve(arg_vec.size());
      for (const Symbol* s : arg_vec) {
398
        n->inputs.push_back(s->outputs[0]);
399 400
      }
    }
401
    UpdateNodeVersion(n);
402 403 404 405 406 407 408 409 410

    FSetInputVarAttrOnCompose fn = fset_attrs.get(n->op(), nullptr);
    if (fn != nullptr) {
      for (size_t i = 0; i < n->inputs.size(); ++i) {
        if (n->inputs[i].node->is_variable()) {
          fn(n->attrs, n->inputs[i].node, i);
        }
      }
    }
411 412
  } else {
    // general composition
413
    CHECK_EQ(args.size(), 0U)
414 415 416 417 418 419
        << "General composition only support kwargs for now";
    size_t nmatched = 0;
    size_t arg_counter = 0;
    std::unordered_map<Node *, const NodeEntry*> replace_map;
    // replace map stores the existing replacement plan for arguments node
    auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map]
420
        (const NodePtr &node) {
421 422
      if (node->is_variable()) {
        if (arg_counter < args.size()) {
423
          replace_map[node.get()] = &(args[arg_counter]->outputs[0]);
424 425 426 427 428
          ++arg_counter;
        } else {
            // match kwargs
          auto kit = kwargs.find(node->attrs.name);
          if (kit != kwargs.end()) {
429
            replace_map[node.get()] = &(kit->second->outputs[0]);
430 431 432 433 434 435 436
            ++nmatched;
          }
        }
      }
    };
    DFSVisit(this->outputs, find_replace_map);

Tianqi Chen committed
437
    if (nmatched == kwargs.size() && arg_counter <= args.size()) {
438
      std::vector<Node*> update_nodes;
439
      std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
440
      auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes]
441
          (const NodePtr &node) {
442
        // visit all the childs, find possible replacement
443
        bool repl = false;
444 445 446 447 448 449
        for (size_t i = 0; i < node->inputs.size(); ++i) {
          NodeEntry *e = &(node->inputs[i]);
          if (e->node->is_variable()) {
            auto iter = replace_map.find(e->node.get());
            if (iter != replace_map.end()) {
              replace_plan.push_back(std::make_pair(e, iter->second));
450
              repl = true;
451 452 453
            }
          }
        }
454
        if (repl) update_nodes.push_back(node.get());
455 456 457 458 459 460
      };
      DFSVisit(this->outputs, find_replace_plan);

      for (const auto& kv : replace_plan) {
        *(kv.first) = *(kv.second);
      }
461 462 463
      for (Node* n : update_nodes) {
        UpdateNodeVersion(n);
      }
464 465
    } else {
      std::vector<std::string> keys = GetKeys(kwargs);
466
      std::vector<std::string> arg_names = ListInputNames(kAll);
467 468
      array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_counter,
                                   dmlc::BeginPtr(arg_names) + arg_names.size());
469
      KeywordArgumentMismatch("Symbol.Compose", keys, arg_names);
470
    }
471 472 473 474 475 476 477 478 479

    // update outputs in case the composed variable is part of outputs.
    for (size_t i = 0; i < outputs.size(); ++i) {
      if (outputs[i].node->is_variable()) {
        CHECK_EQ(args.size(), 0) << "Variable composition only supports keyword arguments";
        const auto it = kwargs.find(outputs[i].node->attrs.name);
        if (it != kwargs.end()) outputs[i] = it->second->outputs[0];
      }
    }
480 481 482
  }
}

483 484
// Functional form of Compose: this symbol is left untouched; a copy is
// composed with the given arguments and returned.
Symbol Symbol::operator () (const array_view<const Symbol*>& args,
                            const std::unordered_map<std::string, const Symbol*>& kwargs,
                            const std::string& name) const {
  Symbol composed = this->Copy();
  composed.Compose(args, kwargs, name);
  return composed;
}

void Symbol::AddControlDeps(const Symbol& src) {
492
  CHECK_EQ(outputs.size(), 1U)
493 494 495 496 497 498 499 500
      << "AddControlDeps only works for nongrouped symbol";
  Node* n = outputs[0].node.get();
  for (const NodeEntry& sp : src.outputs) {
    n->control_deps.push_back(sp.node);
  }
}

// Build a symbol whose outputs expose every node reachable from this one.
// A variable contributes a single entry carrying its current version; an
// operator contributes one entry (version 0) per output, limited to the
// count given by FNumVisibleOutputs when the op registers it.
Symbol Symbol::GetInternals() const {
  static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
  Symbol ret;
  DFSVisit(this->outputs, [&ret](const NodePtr& node) {
      Node* raw = node.get();
      if (raw->is_variable()) {
        // grab version from variable.
        VariableParam& param = nnvm::get<VariableParam>(raw->attrs.parsed);
        ret.outputs.emplace_back(NodeEntry{node, 0, param.version});
        return;
      }
      uint32_t nout = raw->num_outputs();
      if (fnum_vis_output.count(raw->op())) {
        nout = fnum_vis_output[raw->op()](raw->attrs);
      }
      uint32_t idx = 0;
      while (idx < nout) {
        ret.outputs.emplace_back(NodeEntry{node, idx, 0});
        ++idx;
      }
    });
  return ret;
}

522 523 524 525 526 527 528 529 530 531 532 533
// Build a symbol whose outputs are the inputs of this symbol's output nodes.
// A node appearing in several output entries contributes its inputs once.
Symbol Symbol::GetChildren() const {
  Symbol ret;
  std::unordered_set<Node*> visited;
  for (const NodeEntry& entry : this->outputs) {
    Node* node = entry.node.get();
    // insert() reports whether the node is new; skip repeats.
    if (!visited.insert(node).second) continue;
    ret.outputs.insert(ret.outputs.end(), node->inputs.begin(), node->inputs.end());
  }
  return ret;
}

534
void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs) {
535 536 537 538 539
  Node* node = outputs[0].node.get();
  for (const NodeEntry& e : outputs) {
    CHECK(node == e.node.get())
        << "Symbol.SetAttrs only works for non-grouped symbol";
  }
540
  for (const auto& kv : attrs) {
541 542 543 544 545 546
    if (kv.first == "name") {
      node->attrs.name = kv.second;
    } else {
      node->attrs.dict[kv.first] = kv.second;
    }
  }
547 548
  if (node->op() != nullptr && node->op()->attr_parser != nullptr) {
    node->op()->attr_parser(&(node->attrs));
549 550 551 552 553 554
  }
}

bool Symbol::GetAttr(const std::string& key, std::string* out) const {
  Node* node = outputs[0].node.get();
  for (const NodeEntry& e : outputs) {
Eric Junyuan Xie committed
555
    if (node != e.node.get()) return false;
556
  }
557 558 559
  if (key == "name") {
    *out = node->attrs.name;
    return true;
560
  } else if (key == "op_name") {
561 562 563
    if (node->attrs.op != nullptr) {
      *out = node->attrs.op->name;
    } else {
564
      *out = "null";  // use null with json
565
    }
566
    return true;
567 568 569 570 571 572 573 574 575
  } else if (key == "_value_index") {
    *out = "";
    for (size_t i = 0; i < outputs.size(); ++i) {
      if (i != 0) {
        *out += ", ";
      }
      *out += std::to_string(outputs[i].index);
    }
    return true;
576
  }
577 578 579 580
  auto it = node->attrs.dict.find(key);
  if (it == node->attrs.dict.end()) return false;
  *out = it->second;
  return true;
581 582
}

583
// Return attributes as a flat map.
// Non-recursive: a copy of the first output node's attribute dict.
// kRecursive: the attributes of every reachable node, each keyed by
// "<node name><kNamespaceSeparator><attribute key>".
std::unordered_map<std::string, std::string> Symbol::ListAttrs(ListAttrOption option) const {
  if (option != kRecursive) {
    return outputs[0].node->attrs.dict;
  }
  std::unordered_map<std::string, std::string> ret;
  DFSVisit(this->outputs, [&ret](const NodePtr& n) {
      for (const auto& attr : n->attrs.dict) {
        ret[n->attrs.name + symbol_constants::kNamespaceSeparator + attr.first] = attr.second;
      }
    });
  return ret;
}

597 598 599 600 601 602 603 604 605 606 607
// List (node name, attribute key, attribute value) for every attribute of
// every node reached by DFSVisit from the outputs.
std::vector<std::tuple<std::string, std::string, std::string> >
    Symbol::ListAttrsRecursive() const {
  std::vector<std::tuple<std::string, std::string, std::string> > ret;
  DFSVisit(this->outputs, [&ret](const NodePtr& n) {
      const std::string& owner = n->attrs.name;
      for (const auto& attr : n->attrs.dict) {
        ret.emplace_back(owner, attr.first, attr.second);
      }
    });
  return ret;
}

608
// Create a symbol backed by one new node of operator `op` with attribute
// dict `attrs`. The op's attr_parser, when present, is run on the node. The
// symbol gets one output per node output, limited to the count given by
// FNumVisibleOutputs when the op registers it.
Symbol Symbol::CreateFunctor(const Op* op,
                             std::unordered_map<std::string, std::string> attrs) {
  static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
  Symbol s;
  NodePtr n = Node::Create();
  n->attrs.op = op;
  n->attrs.dict = std::move(attrs);
  if (n->op()->attr_parser != nullptr) {
    n->op()->attr_parser(&(n->attrs));
  }

  uint32_t nout = n->num_outputs();
  if (fnum_vis_output.count(n->op())) {
    nout = fnum_vis_output[n->op()](n->attrs);
  }
  s.outputs.reserve(nout);
  for (uint32_t out_idx = 0; out_idx < nout; ++out_idx) {
    s.outputs.emplace_back(n, out_idx, 0);
  }
  return s;
}

629 630 631 632 633 634 635 636 637 638 639
// Create a symbol backed by one new node whose attributes are a copy of
// `attrs`. The symbol gets one output per node output, limited to the count
// given by FNumVisibleOutputs when the op registers it.
Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) {
  static auto& fnum_vis_output = Op::GetAttr<FNumVisibleOutputs>("FNumVisibleOutputs");
  Symbol s;
  NodePtr n = Node::Create();
  n->attrs = attrs;

  uint32_t nout = n->num_outputs();
  if (fnum_vis_output.count(n->op())) {
    nout = fnum_vis_output[n->op()](n->attrs);
  }
  uint32_t out_idx = 0;
  while (out_idx < nout) {
    s.outputs.emplace_back(n, out_idx, 0);
    ++out_idx;
  }
  return s;
}

645 646 647 648 649 650 651 652 653 654
// Concatenate the outputs of `symbols`, in order, into one grouped symbol.
Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
  Symbol grouped;
  for (const Symbol& member : symbols) {
    for (const NodeEntry& entry : member.outputs) {
      grouped.outputs.push_back(entry);
    }
  }
  return grouped;
}

// Create a symbol whose only output is a new variable node called `name`
// (output index 0, version 0).
Symbol Symbol::CreateVariable(const std::string& name) {
  Symbol var;
  var.outputs.emplace_back(CreateVariableNode(name), 0, 0);
  return var;
}

}  // namespace nnvm