Commit 93534ece by Tianqi Chen

[SYMBOLIC] Add symbolic API (#2)

* [SYMBOLIC] Add symbolic API

* Update Testcase to nnvm
parent 5d407324
......@@ -14,11 +14,13 @@
#include <unordered_set>
#include "./base.h"
#include "./node.h"
#include "./symbolic.h"
namespace nnvm {
/*!
* \brief Symbolic computation graph.
* This is the intermediate representation for optimization pass.
*/
class Graph {
public:
......@@ -30,15 +32,17 @@ class Graph {
* and can be shared across multiple Instance of graph
*/
std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
/*!
};
/*!
* \brief perform a Post Order DFS visit to each node in the graph.
* This order is deterministic and is also topoligical sorted.
* \param heads The heads in the graph.
* \param fvisit a function of type std::function<void(const std::shared_ptr<Node>&)>
* \tparam FVisit The function type to perform the visit.
*/
template<typename FVisit>
inline void DFSVisit(FVisit fvisit) const;
};
template<typename FVisit>
inline void DFSVisit(const std::vector<NodeEntry>& heads, FVisit fvisit);
// inline function implementations
template <typename GNode, typename HashType,
......@@ -75,10 +79,11 @@ void PostOrderDFSVisit(const std::vector<GNode>& heads,
}
template<typename FVisit>
inline void Graph::DFSVisit(FVisit fvisit) const {
inline void DFSVisit(const std::vector<NodeEntry>& heads,
FVisit fvisit) {
typedef const std::shared_ptr<Node>* GNode;
std::vector<GNode> head_nodes(outputs.size());
std::transform(outputs.begin(), outputs.end(), head_nodes.begin(),
std::vector<GNode> head_nodes(heads.size());
std::transform(heads.begin(), heads.end(), head_nodes.begin(),
[](const NodeEntry& e)->GNode {
return &e.node;
});
......
......@@ -72,6 +72,8 @@ class Node {
inline bool is_variable() const;
/*! \return number of outputs from this node */
inline uint32_t num_outputs() const;
/*! \return number of inputs from this node */
inline uint32_t num_inputs() const;
/*!
* \brief create a new empty shared_ptr of Node.
* \return a created empty node.
......@@ -86,10 +88,19 @@ inline bool Node::is_variable() const {
inline uint32_t Node::num_outputs() const {
if (is_variable()) return 1;
if (this->op->num_outputs >= 0) {
return static_cast<uint32_t>(this->op->num_outputs);
if (this->op->get_num_outputs == nullptr) {
return this->op->num_outputs;
} else {
return this->op->get_num_outputs(*this);
return this->op->get_num_outputs(this->attrs);
}
}
inline uint32_t Node::num_inputs() const {
if (is_variable()) return 1;
if (this->op->get_num_inputs == nullptr) {
return this->op->num_inputs;
} else {
return this->op->get_num_inputs(this->attrs);
}
}
......
......@@ -10,6 +10,7 @@
#include <vector>
#include <utility>
#include <typeinfo>
#include <limits>
#include <functional>
#include "./base.h"
......@@ -22,8 +23,8 @@ template<typename ValueType>
class OpMap;
class OpRegistryEntry;
/*! \brief constant to indicate variable length inout and output */
static const int kVarg = -1;
/*! \brief constant to indicate it take any length of positional inputs */
static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
/*!
* \brief Operator structure.
......@@ -79,23 +80,31 @@ class Op {
/*!
* \brief number of inputs to the operator,
* -1 means it is variable length
* When get_num_inputs is presented,
* the number will be decided by get_num_inputs instead.
* \sa get_num_inputs
*/
int num_inputs = 0;
uint32_t num_inputs = 1;
/*!
* \brief number of outputs of the operator
* -1 means it is variable length
* When get_num_outputs is presented.
* The number of outputs will be decided by
* get_num_outputs function
* \sa get_num_outputs
*/
int num_outputs = 1;
uint32_t num_outputs = 1;
/*!
* \brief get number of outputs given information about the node.
* This is only valid when num_outputs == -1.
* \param node The constructed node.
* \param attrs The attribute of the node
* \return number of outputs.
*/
int (*get_num_outputs)(const Node& node) = nullptr;
uint32_t (*get_num_outputs)(const NodeAttrs& attrs) = nullptr;
/*!
* \brief get number of inputs given information about the node.
* \param attrs The attribute of the node
* \return number of inputs
*/
uint32_t (*get_num_inputs)(const NodeAttrs& attrs) = nullptr;
/*!
* \brief Attribute parser to parse the NodeAttrs information.
*
......@@ -143,19 +152,25 @@ class Op {
* \param n The number of inputs to be set.
* \return reference to self.
*/
inline Op& set_num_inputs(int n); // NOLINT(*)
inline Op& set_num_inputs(uint32_t n); // NOLINT(*)
/*!
* \brief Set the get_num_outputs function.
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_inputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*)
/*!
* \brief Set the num_outputs
* \param n The number of outputs to be set.
* \return reference to self.
*/
inline Op& set_num_outputs(int n); // NOLINT(*)
inline Op& set_num_outputs(uint32_t n); // NOLINT(*)
/*!
* \brief Set the get_num_outputs function.
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_outputs(int (*fn)(const Node& node)); // NOLINT(*)
inline Op& set_num_outputs(uint32_t (*fn)(const NodeAttrs& attr)); // NOLINT(*)
/*!
* \brief Set the attr_parser function.
* \param fn The number of outputs to be set.
......@@ -180,6 +195,7 @@ class Op {
static const Op* Get(const std::string& op_name);
/*!
* \brief Get additional registered attribute about operators.
* If nothing has been registered, an empty OpMap will be returned.
* \param attr_name The name of the attribute.
* \return An OpMap of specified attr_name.
* \tparam ValueType The type of the attribute.
......@@ -197,7 +213,7 @@ class Op {
// internal constructor
Op();
// get const reference to certain attribute
static const any& GetAttrMap(const std::string& key);
static const any* GetAttrMap(const std::string& key);
// update the attribute OpMap
static void UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater);
......@@ -218,6 +234,13 @@ class OpMap {
*/
inline const ValueType& operator[](const Op* op) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param op The key to the map
* \param def_value The default value when the key does not exist.
* \return the const reference to the content value.
*/
inline const ValueType& get(const Op* op, const ValueType& def_value) const;
/*!
* \brief Check if the map has op as key.
* \param op The key to the map
* \return 1 if op is contained in map, 0 otherwise.
......@@ -262,8 +285,18 @@ class OpMap {
// member function of Op
template<typename ValueType>
inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
const any& ref = GetAttrMap(key);
return nnvm::get<OpMap<ValueType> >(ref);
const any* ref = GetAttrMap(key);
if (ref == nullptr) {
UpdateAttrMap(key, [key](any* pmap) {
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = key;
*pmap = std::move(pm);
}
});
ref = GetAttrMap(key);
}
return nnvm::get<OpMap<ValueType> >(*ref);
}
template<typename ValueType>
......@@ -273,7 +306,7 @@ inline Op& Op::attr( // NOLINT(*)
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = attr_name;
*pmap = pm;
*pmap = std::move(pm);
}
CHECK_EQ(pmap->type(), typeid(OpMap<ValueType>))
<< "Attribute " << attr_name
......@@ -301,18 +334,22 @@ inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
return *this;
}
inline Op& Op::set_num_inputs(int n) { // NOLINT(*)
inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
this->num_inputs = n;
return *this;
}
inline Op& Op::set_num_outputs(int n) { // NOLINT(*)
inline Op& Op::set_num_inputs(uint32_t (*fn)(const NodeAttrs&)) { // NOLINT(*)
this->get_num_inputs = fn;
return *this;
}
inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
this->num_outputs = n;
return *this;
}
inline Op& Op::set_num_outputs(int (*fn)(const Node& node)) { // NOLINT(*)
this->num_outputs = kVarg;
inline Op& Op::set_num_outputs(uint32_t (*fn)(const NodeAttrs&)) { // NOLINT(*)
this->get_num_outputs = fn;
return *this;
}
......@@ -338,6 +375,16 @@ inline const ValueType& OpMap<ValueType>::operator[](const Op* op) const {
return data_[idx].first;
}
template<typename ValueType>
inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def_value) const {
const uint32_t idx = op->index_;
if (idx < data_.size() && data_[idx].second) {
return data_[idx].first;
} else {
return def_value;
}
}
} // namespace nnvm
#endif // NNVM_OP_H_
/*!
* Copyright (c) 2016 by Contributors
* \file op_attr_types.h
* \brief Data structures that can appear in operator attributes.
*/
#ifndef NNVM_OP_ATTR_TYPES_H_
#define NNVM_OP_ATTR_TYPES_H_
#include <vector>
#include <string>
#include <functional>
namespace nnvm {
// These types are optional attributes in each op
// Some of them are needed for certain pass.
/*!
* \brief Return list of input arguments names of each operator.
*
* \param attrs The attributes of the node.
* \return list of inputs
* \note Register under "FListInputNames", default return {"data"}.
*
* FListInputNames enables automatic variable creation for missing arguments.
*/
using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
/*!
* \brief Return list of output arguments names of each operator.
*
* \param attrs The attributes of the node.
* \return list of inputs
* \note Register under "FListOutputNames", default return {"outputs"}.
*
* FListOutputNames customized naming for operator outputs.
*/
using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;
} // namespace nnvm
#endif // NNVM_OP_ATTR_TYPES_H_
/*!
* Copyright (c) 2016 by Contributors
* \file symbolic.h
* \brief Symbolic graph construction API
*
* This API is optional, but useful to allow user
* to construct NNVM Graph easily, and quickly create
* front-end host languages.
*/
#ifndef NNVM_SYMBOLIC_H_
#define NNVM_SYMBOLIC_H_
#include <string>
#include <vector>
#include <utility>
#include "./base.h"
#include "./node.h"
namespace nnvm {
/*!
* \brief Symbol is used to represent the
*/
class Symbol {
public:
/*! \brief option passed to ListAttr */
enum ListAttrOption {
/*! \brief recursively list all attributes */
kRecursive,
/*! \brief only list attributes in current node */
kShallow
};
/*! \brief output entries contained in the symbol */
std::vector<NodeEntry> outputs;
/*!
* \brief copy the symbol
* \return a deep copy of the symbolic graph.
*/
Symbol Copy() const;
/*!
* \brief print the symbol info to output stream.
* \param os the output stream we like to print to
*/
void Print(std::ostream &os) const; // NOLINT(*)
/*!
* \brief get the index th element from the returned tuple.
* \param index index of multi output
* \return the symbol corresponds to the indexed element.
*/
Symbol operator[] (size_t index) const;
/*!
* \brief List the arguments names.
*
* The position of the returned list also corresponds to calling position in operator()
* \return the arguments list of this symbol, they can be either named or unnamed (empty string).
*/
std::vector<std::string> ListArguments() const;
/*!
* \brief List the names of outputs for this symbol.
* For normal operators, it is usually symbol node name + "_output"
* \return get the descriptions of outputs for this symbol.
*/
std::vector<std::string> ListOutputs() const;
/*!
* \brief Compose the symbol with arguments, this changes the current symbol.
* The kwargs passed in can be in-complete,
*
* The rest of the symbols will remain the same name.
*
* \param positional arguments
* \param kwargs keyword arguments for the symbol
* \param name name of returned symbol.
*/
void Compose(const std::vector<Symbol>& args,
const std::unordered_map<std::string, Symbol>& kwargs,
const std::string& name);
/*!
* \brief Apply the symbol as a function, compose with arguments
* This is equivalent to Copy then Compose.
* \param args positional arguments for the symbol
* \param kwargs keyword arguments for the symbol
* \param name name of returned symbol.
* \return a new Symbol which is the composition of current symbol with its arguments
*/
Symbol operator () (const std::vector<Symbol>& args,
const std::unordered_map<std::string, Symbol>& kwargs,
const std::string& name) const;
/*!
* \brief Add control flow depenencies to operators involved in symbols.
* For grouped sybmbol, an error will be raised.
* This mutate current symbolic Node.
*
* \param src The symbols to depend on.
*/
void AddControlDeps(const Symbol& src);
/*
* \brief Get all the internal nodes of the symbol.
* \return symbol A new symbol whose output contains all the outputs of the symbols
* Including input variables and intermediate outputs.
*/
Symbol GetInternals() const;
/*!
* \brief set additional attributes to current node.
* This only works for symbol with outputs from single operators.
* For grouped sybmbol, an error will be raised.
*
* This function mutate the node's symbol and is not recommended.
*
* \param key the key of the attribute
* \param value the value of the attribute.
*/
void SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs);
/*!
* \brief Get attribute dictionary from the symbol.
* For grouped sybmbol, an error will be raised.
* \param option If recursive is set, the attributes of all children are retrieved,
* The name of symbol will be pre-pended to each key.
* \return The created attribute.
*/
std::unordered_map<std::string, std::string> ListAttr(ListAttrOption option) const;
/*!
* \brief create symbolic functor(AtomicSymbol) by given operator and attributes.
* \param op_name The name of the operator.
* \param attrs The additional attributes.
*
* \return Symbol that can be used to call compose further.
*/
static Symbol CreateFunctor(const std::string& op_name,
const std::unordered_map<std::string, std::string>& attrs);
/*!
* \brief create variable symbol node
* \param name name of the variable
* \return the new variable
*/
static Symbol CreateVariable(const std::string& name);
/*!
* \brief create equivalence of symbol by grouping the symbols together
* \param symbols list of symbols
* \return the grouped symbol
*/
static Symbol CreateGroup(const std::vector<Symbol>& symbols);
};
} // namespace nnvm
#endif // NNVM_SYMBOLIC_H_
......@@ -13,7 +13,7 @@ IndexedGraph::IndexedGraph(const Graph &g) {
entry_rptr_.push_back(0);
std::vector<size_t> inputs_rptr{0}, control_rptr{0};
g.DFSVisit([this, &inputs_rptr, &control_rptr]
DFSVisit(g.outputs, [this, &inputs_rptr, &control_rptr]
(const std::shared_ptr<nnvm::Node>& n) {
CHECK_LT(nodes_.size(), std::numeric_limits<uint32_t>::max());
uint32_t nid = static_cast<uint32_t>(nodes_.size());
......
......@@ -6,6 +6,7 @@
#include <nnvm/base.h>
#include <nnvm/op.h>
#include <memory>
#include <atomic>
#include <mutex>
......@@ -23,7 +24,7 @@ struct OpManager {
// global operator counter
std::atomic<int> op_counter{0};
// storage of additional attribute table.
std::unordered_map<std::string, any> attr;
std::unordered_map<std::string, std::unique_ptr<any> > attr;
// get singleton of the
static OpManager* Global() {
static OpManager inst;
......@@ -46,24 +47,24 @@ const Op* Op::Get(const std::string& name) {
}
// Get attribute map by key
const any& Op::GetAttrMap(const std::string& key) {
// assume no operator registration during
// the execution phase.
const auto& dict = OpManager::Global()->attr;
const any* Op::GetAttrMap(const std::string& key) {
auto& dict = OpManager::Global()->attr;
auto it = dict.find(key);
CHECK(it != dict.end() && it->first == key)
<< "Cannot find Operator attribute " << key
<< " for any operator";
return it->second;
if (it != dict.end()) {
return it->second.get();
} else {
return nullptr;
}
}
// update attribute map by updater function.
// update attribute map
void Op::UpdateAttrMap(const std::string& key,
std::function<void(any*)> updater) {
OpManager* mgr = OpManager::Global();
std::lock_guard<std::mutex>(mgr->mutex);
any& value = mgr->attr[key];
updater(&value);
std::unique_ptr<any>& value = mgr->attr[key];
if (value.get() == nullptr) value.reset(new any());
if (updater != nullptr) updater(value.get());
}
} // namespace nnvm
/*!
* Copyright (c) 2016 by Contributors
* \file symbolic.cc
* \brief Symbolic graph composition API.
*/
#include <nnvm/graph.h>
#include <nnvm/symbolic.h>
#include <nnvm/op_attr_types.h>
namespace nnvm {
namespace symbol_constants {
const char *kNamespaceSeparator = "_";
} // namespace symbol_constants
inline std::string DefaultVarName(const std::string &op_name,
const std::string &arg_name) {
if (op_name.length() == 0) {
return arg_name;
} else {
return op_name + '_' + arg_name;
}
}
inline void KeywordArgumentMismatch(const char *source,
const std::vector<std::string>& user_args,
const array_view<std::string>& args) {
std::unordered_set<std::string> keys(args.begin(), args.end());
std::ostringstream head, msg;
msg << "\nCandidate arguments:\n";
for (size_t i = 0; i < args.size(); ++i) {
msg << "\t[" << i << ']' << args[i] << '\n';
}
for (const auto& key : user_args) {
if (keys.count(key) == 0) {
LOG(FATAL) << source
<< "Keyword argument name " << key << " not found."
<< msg.str();
}
}
}
template<typename T>
inline std::vector<std::string> GetKeys(
const std::unordered_map<std::string, T>& kwargs) {
std::vector<std::string> keys(kwargs.size());
std::transform(kwargs.begin(), kwargs.end(), keys.begin(),
[](decltype(*kwargs.begin())& kv) { return kv.first; });
return keys;
}
// whether the symbol is atomic functor
inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
return outputs.size() == 1 && outputs[0].node->inputs.size() == 0;
}
// public functions
Symbol Symbol::Copy() const {
std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
// use DFSVisit to copy all the nodes
DFSVisit(this->outputs, [&old_new](const std::shared_ptr<Node>& node) {
old_new[node.get()] = std::make_shared<Node>(*node);
});
// connect nodes of new graph
for (const auto &kv : old_new) {
for (const NodeEntry& e : kv.first->inputs) {
Node *ptr = e.node.get();
kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index});
}
}
// set the head
Symbol ret;
for (const NodeEntry &e : outputs) {
ret.outputs.emplace_back(NodeEntry{old_new[e.node.get()], e.index});
}
return ret;
}
void Symbol::Print(std::ostream &os) const {
if (outputs.size() == 1 && outputs[0].node->inputs.size() == 0) {
os << "AtomicFunctor "<< " Op:" << outputs[0].node->op->name << '\n';
} else {
// use DFSVisit to copy all the nodes
os << "Outputs:\n";
for (size_t i = 0; i < outputs.size(); ++i) {
os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name
<< '(' << outputs[i].index << ")\n";
}
DFSVisit(this->outputs, [&os](const std::shared_ptr<Node>& node) {
if (node->is_variable()) {
os << "Variable:" << node->attrs.name << '\n';
} else {
os << "Name: " << node->attrs.name << " Op:" << node->op->name << '\n'
<< "Inputs:\n";
for (size_t i = 0; i < node->inputs.size(); ++i) {
os << "\targ[" << i << "]=" << node->inputs[i].node->attrs.name
<< '(' << node->inputs[i].index << ")\n";
}
os << "Attrs:\n";
for (auto &kv : node->attrs.dict) {
os << '\t' << kv.first << '=' << kv.second << '\n';
}
}
});
}
}
Symbol Symbol::operator[] (size_t index) const {
size_t nreturn = outputs.size();
CHECK_LT(index, nreturn) << "Symbol only accept nonnegative index";
if (nreturn == 1) {
return *this;
} else {
Symbol s;
s.outputs.push_back(outputs[index]);
return s;
}
}
std::vector<std::string> Symbol::ListArguments() const {
std::vector<std::string> ret;
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node> &node) {
if (node->is_variable()) {
ret.push_back(node->attrs.name);
}
});
return ret;
}
std::vector<std::string> Symbol::ListOutputs() const {
static auto& flist_ouputs = Op::GetAttr<FListOutputNames>("FListOutputNames");
std::vector<std::string> ret;
for (auto &head : outputs) {
if (head.node->is_variable()) {
ret.push_back(head.node->attrs.name);
} else {
const std::string& hname = head.node->attrs.name;
std::string rname;
FListOutputNames fn = flist_ouputs.get(head.node->op, nullptr);
if (fn != nullptr) {
rname = fn(head.node->attrs)[head.index];
} else {
rname = "output";
if (head.node->num_outputs() != 1) {
std::ostringstream os;
os << rname << head.index;
rname = os.str();
}
}
if (hname.length() == 0) {
ret.push_back(std::move(rname));
} else {
ret.push_back(hname + '_' + rname);
}
}
}
return ret;
}
// compositional logic
void Symbol::Compose(const std::vector<Symbol>& args,
const std::unordered_map<std::string, Symbol>& kwargs,
const std::string& name) {
CHECK_EQ(outputs.size(), 1)
<< "Only composition of value function is supported currently";
CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
// parameter check.
for (size_t i = 0; i < args.size(); ++i) {
CHECK_EQ(args[i].outputs.size(), 1)
<< "Argument " << i << " is a tuple, single value is required";
}
for (const auto& kv : kwargs) {
CHECK_EQ(kv.second.outputs.size(), 1)
<< "Keyword Argument " << kv.first << " is a tuple, single value is required";
}
// assign new name
outputs[0].node->attrs.name = name;
// Atomic functor composition.
if (IsAtomic(outputs)) {
Node* n = outputs[0].node.get();
uint32_t n_req = n->num_inputs();
if (n_req != kVarg) {
n->inputs.resize(n_req);
CHECK_LE(args.size(), n_req)
<< "Incorrect number of arguments, requires " << n_req
<< ", provided " << args.size();
for (size_t i = 0; i < args.size(); ++i) {
n->inputs[i] = args[i].outputs[0];
}
// switch to keyword argument matching
if (args.size() != n_req) {
static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
FListInputNames fn = flist_inputs.get(n->op, nullptr);
auto arg_names = (fn == nullptr) ? std::vector<std::string>{"data"} : fn(n->attrs);
CHECK_EQ(arg_names.size(), n_req);
size_t nmatched = 0;
for (size_t i = args.size(); i < n_req; ++i) {
auto it = kwargs.find(arg_names[i]);
if (it != kwargs.end() && it->first == arg_names[i]) {
n->inputs[i] = it->second.outputs[0];
++nmatched;
} else {
n->inputs[i] = NodeEntry{Node::Create(), 0};
n->inputs[i].node->attrs.name = DefaultVarName(name, arg_names[i]);
}
}
if (nmatched != kwargs.size()) {
n->inputs.clear();
std::vector<std::string> keys = GetKeys(kwargs);
array_view<std::string> view(dmlc::BeginPtr(arg_names) + args.size(),
dmlc::BeginPtr(arg_names) + arg_names.size());
KeywordArgumentMismatch("Symbol.Compose", keys, view);
}
}
} else {
CHECK_EQ(kwargs.size(), 0) << "Variable length function do not accept kwargs";
n->inputs.reserve(args.size());
for (const Symbol& s : args) {
n->inputs.push_back(s.outputs[0]);
}
}
} else {
// general composition
CHECK_EQ(args.size(), 0)
<< "General composition only support kwargs for now";
size_t nmatched = 0;
size_t arg_counter = 0;
std::unordered_map<Node *, const NodeEntry*> replace_map;
// replace map stores the existing replacement plan for arguments node
auto find_replace_map = [&nmatched, &arg_counter, &args, &kwargs, &replace_map]
(const std::shared_ptr<Node> &node) {
if (node->is_variable()) {
if (arg_counter < args.size()) {
replace_map[node.get()] = &(args[arg_counter].outputs[0]);
++arg_counter;
} else {
// match kwargs
auto kit = kwargs.find(node->attrs.name);
if (kit != kwargs.end()) {
replace_map[node.get()] = &(kit->second.outputs[0]);
++nmatched;
}
}
}
};
DFSVisit(this->outputs, find_replace_map);
if (nmatched == kwargs.size() && arg_counter < args.size()) {
std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
auto find_replace_plan = [&replace_map, &replace_plan]
(const std::shared_ptr<Node> &node) {
// visit all the childs, find possible replacement
for (size_t i = 0; i < node->inputs.size(); ++i) {
NodeEntry *e = &(node->inputs[i]);
if (e->node->is_variable()) {
auto iter = replace_map.find(e->node.get());
if (iter != replace_map.end()) {
replace_plan.push_back(std::make_pair(e, iter->second));
}
}
}
};
DFSVisit(this->outputs, find_replace_plan);
for (const auto& kv : replace_plan) {
*(kv.first) = *(kv.second);
}
} else {
std::vector<std::string> keys = GetKeys(kwargs);
std::vector<std::string> arg_names = ListArguments();
array_view<std::string> view(dmlc::BeginPtr(arg_names) + arg_counter,
dmlc::BeginPtr(arg_names) + arg_names.size());
KeywordArgumentMismatch("Symbol.Compose", keys, ListArguments());
}
}
}
Symbol Symbol::operator () (const std::vector<Symbol>& args,
const std::unordered_map<std::string, Symbol>& kwargs,
const std::string& name) const {
Symbol s = this->Copy();
s.Compose(args, kwargs, name);
return s;
}
void Symbol::AddControlDeps(const Symbol& src) {
CHECK_EQ(outputs.size(), 1)
<< "AddControlDeps only works for nongrouped symbol";
Node* n = outputs[0].node.get();
for (const NodeEntry& sp : src.outputs) {
n->control_deps.push_back(sp.node);
}
}
Symbol Symbol::GetInternals() const {
Symbol ret;
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node>& node) {
Node* n = node.get();
uint32_t nout = n->num_outputs();
for (uint32_t i = 0; i < nout; ++i) {
ret.outputs.emplace_back(NodeEntry{node, i});
}
});
return ret;
}
void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs) {
CHECK_EQ(outputs.size(), 1)
<< "SetAttrs only works for nongrouped symbol";
Node* n = outputs[0].node.get();
for (const auto& kv : attrs) {
n->attrs.dict[kv.first] = kv.second;
}
if (n->op->attr_parser != nullptr) {
(*n->op->attr_parser)(&(n->attrs));
}
}
std::unordered_map<std::string, std::string> Symbol::ListAttr(ListAttrOption option) const {
if (option == kRecursive) {
std::unordered_map<std::string, std::string> ret;
DFSVisit(this->outputs, [&ret](const std::shared_ptr<Node>& n) {
for (const auto& it : n->attrs.dict) {
ret[n->attrs.name + symbol_constants::kNamespaceSeparator + it.first] = it.second;
}
});
return ret;
} else {
return outputs[0].node->attrs.dict;
}
}
Symbol Symbol::CreateFunctor(const std::string& op_name,
const std::unordered_map<std::string, std::string>& attrs) {
Symbol s;
std::shared_ptr<Node> n = Node::Create();
n->op = Op::Get(op_name);
n->attrs.dict = attrs;
if (n->op->attr_parser != nullptr) {
(*n->op->attr_parser)(&(n->attrs));
}
s.outputs.emplace_back(NodeEntry{std::move(n), 0});
return s;
}
Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
Symbol ret;
for (const auto &s : symbols) {
ret.outputs.insert(ret.outputs.end(), s.outputs.begin(), s.outputs.end());
}
return ret;
}
Symbol Symbol::CreateVariable(const std::string& name) {
Symbol s;
std::shared_ptr<Node> n = Node::Create();
n->op = nullptr;
n->attrs.name = name;
s.outputs.emplace_back(NodeEntry{std::move(n), 0});
return s;
}
} // namespace nnvm
......@@ -158,7 +158,7 @@ Graph LoadJSON(const Graph& src) {
Graph SaveJSON(const Graph& src) {
JSONGraph jgraph;
std::unordered_map<Node*, uint32_t> node2index;
src.DFSVisit([&node2index, &jgraph](const std::shared_ptr<Node>& n) {
DFSVisit(src.outputs, [&node2index, &jgraph](const std::shared_ptr<Node>& n) {
uint32_t nid = static_cast<uint32_t>(jgraph.nodes.size());
node2index[n.get()] = nid;
if (n->is_variable()) {
......
......@@ -8,7 +8,7 @@
void test_op() {
using namespace nnvm;
auto add = Op::Get("add");
auto nick = Op::GetAttr<std::string>("nick_name");
static auto& nick = Op::GetAttr<std::string>("nick_name");
LOG(INFO) << "nick=" << nick[add];
}
......@@ -35,9 +35,7 @@ void test_tuple() {
void test_graph() {
nnvm::Graph g;
g.DFSVisit([](const std::shared_ptr<const nnvm::Node>& n){
});
nnvm::Symbol s;
}
int main() {
test_tuple();
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <nngraph/op.h>
#include <nnvm/op.h>
#include <utility>
NNGRAPH_REGISTER_OP(add)
NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.attr("inplace_pair", std::make_pair(0, 0));
NNGRAPH_REGISTER_OP(add)
NNVM_REGISTER_OP(add)
.attr<std::string>("nick_name", "plus");
TEST(Op, GetAttr) {
using namespace nngraph;
using namespace nnvm;
auto add = Op::Get("add");
auto nick = Op::GetAttr<std::string>("nick_name");
......
#include <dmlc/logging.h>
#include <gtest/gtest.h>
#include <nngraph/tuple.h>
#include <nnvm/tuple.h>
TEST(Tuple, Basic) {
using nngraph::Tuple;
using nngraph::TShape;
using nnvm::Tuple;
using nnvm::TShape;
Tuple<int> x{1, 2, 3};
Tuple<int> y{1, 2, 3, 5, 6};
x = std::move(y);
......@@ -17,7 +17,7 @@ TEST(Tuple, Basic) {
std::istringstream is(os.str());
is >> y;
CHECK_EQ(x, y);
Tuple<nngraph::index_t> ss{1, 2, 3};
Tuple<nnvm::index_t> ss{1, 2, 3};
TShape s = ss;
s = std::move(ss);
CHECK((s == TShape{1, 2, 3}));
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment