Commit db9a9a79 by Tianqi Chen

[PASS] Add order mutation (#7)

* [PASS] Add order mutation

* A few benchmarks on compose speed
parent badcdfff
@@ -73,8 +73,8 @@ class Symbol {
   * \param kwargs keyword arguments for the symbol
   * \param name name of returned symbol.
   */
-  void Compose(const std::vector<Symbol>& args,
-               const std::unordered_map<std::string, Symbol>& kwargs,
+  void Compose(const array_view<const Symbol*>& args,
+               const std::unordered_map<std::string, const Symbol*>& kwargs,
                const std::string& name);
   /*!
    * \brief Apply the symbol as a function, compose with arguments
@@ -84,8 +84,8 @@ class Symbol {
    * \param name name of returned symbol.
    * \return a new Symbol which is the composition of current symbol with its arguments
    */
-  Symbol operator () (const std::vector<Symbol>& args,
-                      const std::unordered_map<std::string, Symbol>& kwargs,
+  Symbol operator () (const array_view<const Symbol*>& args,
+                      const std::unordered_map<std::string, const Symbol*>& kwargs,
                       const std::string& name) const;
   /*!
    * \brief Add control flow depenencies to operators involved in symbols.
...
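
Note: the change above switches composition from copying whole Symbols into a std::vector to passing a lightweight view of Symbol pointers. A minimal caller-side sketch of the new calling convention, illustrative only and not part of this commit; it assumes an operator named "add" is registered (as in this commit's tests), and the symbol and function names are placeholders:

#include <nnvm/op.h>
#include <nnvm/symbolic.h>
#include <string>
#include <unordered_map>
#include <vector>

void compose_example() {
  const nnvm::Op* add = nnvm::Op::Get("add");           // must be a registered op
  nnvm::Symbol x = nnvm::Symbol::CreateVariable("x");
  nnvm::Symbol out = nnvm::Symbol::CreateFunctor(add, {});
  // Positional arguments are now passed by pointer; no Symbol is copied here.
  // A std::vector of pointers converts implicitly to array_view<const Symbol*>,
  // which is how the new test_speed benchmark below calls Compose as well.
  std::vector<const nnvm::Symbol*> args{&x, &x};
  std::unordered_map<std::string, const nnvm::Symbol*> kwargs;  // none used
  out.Compose(args, kwargs, "add0");
}
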
@@ -10,6 +10,7 @@
 #include <dmlc/logging.h>
 #include <dmlc/thread_local.h>
 #include <nnvm/c_api.h>
+#include <nnvm/symbolic.h>
 #include <vector>
 #include <string>
@@ -36,6 +37,8 @@ struct NNAPIThreadLocalEntry {
   std::vector<const char *> ret_vec_charp;
   /*! \brief result holder for returning handles */
   std::vector<void *> ret_handles;
+  /*! \brief argument holder to hold symbol */
+  std::unordered_map<std::string, const nnvm::Symbol*> kwarg_symbol;
 };
 /*! \brief Thread local store that can be used to hold return values. */
...
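
Note: the new kwarg_symbol member lives in the per-thread scratch entry so the C API can reuse the same map across calls instead of allocating a fresh one each time; that allocation cost is part of what the compose-speed benchmarks below measure. A reduced sketch of the pattern, not this commit's code, assuming NNAPIThreadLocalStore is the usual dmlc::ThreadLocalStore typedef over this entry (the typedef line itself is outside the shown hunk):

#include <dmlc/thread_local.h>
#include <string>
#include <unordered_map>

// Stand-in for NNAPIThreadLocalEntry, reduced to the member added above.
struct ScratchEntry {
  std::unordered_map<std::string, const void*> kwarg_symbol;
};
using ScratchStore = dmlc::ThreadLocalStore<ScratchEntry>;

void some_c_api_call() {
  // One entry per thread, created lazily and reused by every later call on
  // that thread, so the map's storage is recycled rather than re-allocated.
  ScratchEntry* ret = ScratchStore::Get();
  ret->kwarg_symbol.clear();
  // ... fill ret->kwarg_symbol with the caller's keyword arguments ...
}
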
@@ -217,22 +217,26 @@ int NNSymbolCompose(SymbolHandle sym,
                     const char** keys,
                     SymbolHandle* args) {
   API_BEGIN();
-  std::string s_name;
-  if (name != nullptr) s_name = name;
+  NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
+  std::string& s_name = ret->ret_str;
+  std::unordered_map<std::string, const Symbol*>& kwargs
+      = ret->kwarg_symbol;
+  if (name != nullptr) {
+    s_name = name;
+  } else {
+    s_name.clear();
+  }
   Symbol* s = static_cast<Symbol*>(sym);
   if (keys == nullptr && num_args != 0) {
-    std::vector<Symbol> pos_args;
-    for (nn_uint i = 0; i < num_args; ++i) {
-      pos_args.push_back(*((Symbol*)args[i]));  // NOLINT(*)
-    }
-    s->Compose(pos_args, {}, s_name);
+    kwargs.clear();
+    array_view<const Symbol*> parg(
+        (Symbol**)args, (Symbol**)args + num_args);  // NOLINT(*)
+    s->Compose(parg, kwargs, s_name);
   } else {
-    std::unordered_map<std::string, Symbol> kwargs;
     for (nn_uint i = 0; i < num_args; ++i) {
-      kwargs[keys[i]] = *((Symbol*)args[i]);  // NOLINT(*)
+      kwargs[keys[i]] = (Symbol*)args[i];  // NOLINT(*)
     }
-    s->Compose({}, kwargs, s_name);
+    s->Compose(array_view<const Symbol*>(), kwargs, s_name);
   }
   API_END();
 }
@@ -45,7 +45,7 @@ inline void UpdateNodeVersion(Node *n) {
       CHECK(e.node->is_variable())
           << "Mutation target can only be Variable";
       // increase the version of the variable.
-      ++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
+      e.version = ++nnvm::get<VariableParam>(e.node->attrs.parsed).version;
     }
   }
 }
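
With this change the edge performing the mutation records the post-increment version number on itself, not just on the Variable's counter; the OrderMutation pass below seeds its per-variable history only for variables whose edges carry a non-zero version. A toy illustration of the counter, not nnvm code (names are made up):

#include <cstdint>
#include <cstdio>

struct ToyVariable { uint32_t version = 0; };  // stands in for VariableParam

int main() {
  ToyVariable x;
  uint32_t read_before = x.version;   // a read before any write records 0
  uint32_t write = ++x.version;       // the mutating edge records 1
  uint32_t read_after = x.version;    // a read issued after the write records 1
  std::printf("%u %u %u\n", read_before, write, read_after);
  return 0;
}
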
@@ -98,7 +98,10 @@ Symbol Symbol::Copy() const {
   std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
   // use DFSVisit to copy all the nodes
   DFSVisit(this->outputs, [&old_new](const std::shared_ptr<Node>& node) {
-      old_new[node.get()] = std::make_shared<Node>(*node);
+      std::shared_ptr<Node> np = Node::Create();
+      np->op = node->op;
+      np->attrs = node->attrs;
+      old_new[node.get()] = std::move(np);
     });
   // connect nodes of new graph
   for (const auto &kv : old_new) {
@@ -106,6 +109,9 @@ Symbol Symbol::Copy() const {
       Node *ptr = e.node.get();
       kv.second->inputs.emplace_back(NodeEntry{old_new[ptr], e.index, e.version});
     }
+    for (const std::shared_ptr<Node>& p : kv.first->control_deps) {
+      kv.second->control_deps.emplace_back(old_new[p.get()]);
+    }
   }
   // set the head
   Symbol ret;
@@ -120,7 +126,7 @@ void Symbol::Print(std::ostream &os) const {
     os << "AtomicFunctor "<< " Op:" << outputs[0].node->op->name << '\n';
   } else {
     // use DFSVisit to copy all the nodes
-    os << "Outputs:\n";
+    os << "Symbol Outputs:\n";
     for (size_t i = 0; i < outputs.size(); ++i) {
       os << "\toutput[" << i << "]=" << outputs[i].node->attrs.name
          << '(' << outputs[i].index << ")\n";
@@ -129,7 +135,8 @@ void Symbol::Print(std::ostream &os) const {
         if (node->is_variable()) {
           os << "Variable:" << node->attrs.name << '\n';
         } else {
-          os << "Name: " << node->attrs.name << " Op:" << node->op->name << '\n'
+          os << "--------------------\n";
+          os << "Op:" << node->op->name << ", Name=" << node->attrs.name << '\n'
              << "Inputs:\n";
           for (size_t i = 0; i < node->inputs.size(); ++i) {
             const NodeEntry& e = node->inputs[i];
@@ -141,9 +148,17 @@ void Symbol::Print(std::ostream &os) const {
               os << '\n';
             }
           }
-          os << "Attrs:\n";
-          for (auto &kv : node->attrs.dict) {
-            os << '\t' << kv.first << '=' << kv.second << '\n';
+          if (!node->attrs.dict.empty()) {
+            os << "Attrs:\n";
+            for (auto &kv : node->attrs.dict) {
+              os << '\t' << kv.first << '=' << kv.second << '\n';
+            }
+          }
+          if (node->control_deps.size() != 0) {
+            os << "Control deps:\n";
+            for (size_t i = 0; i < node->control_deps.size(); ++i) {
+              os << "\tcdep[" << i << "]=" << node->control_deps[i]->attrs.name << '\n';
+            }
           }
         }
       });
@@ -203,8 +218,8 @@ std::vector<std::string> Symbol::ListOutputs() const {
 }
 // compositional logic
-void Symbol::Compose(const std::vector<Symbol>& args,
-                     const std::unordered_map<std::string, Symbol>& kwargs,
+void Symbol::Compose(const array_view<const Symbol*>& args,
+                     const std::unordered_map<std::string, const Symbol*>& kwargs,
                      const std::string& name) {
   static auto& flist_inputs = Op::GetAttr<FListInputNames>("FListInputNames");
@@ -213,11 +228,11 @@ void Symbol::Compose(const std::vector<Symbol>& args,
   CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed";
   // parameter check.
   for (size_t i = 0; i < args.size(); ++i) {
-    CHECK_EQ(args[i].outputs.size(), 1)
+    CHECK_EQ(args[i]->outputs.size(), 1)
         << "Argument " << i << " is a tuple, single value is required";
   }
   for (const auto& kv : kwargs) {
-    CHECK_EQ(kv.second.outputs.size(), 1)
+    CHECK_EQ(kv.second->outputs.size(), 1)
         << "Keyword Argument " << kv.first << " is a tuple, single value is required";
   }
   // assign new name
@@ -234,7 +249,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
         << "Incorrect number of arguments, requires " << n_req
         << ", provided " << args.size();
     for (size_t i = 0; i < args.size(); ++i) {
-      n->inputs[i] = args[i].outputs[0];
+      n->inputs[i] = args[i]->outputs[0];
     }
     // switch to keyword argument matching
     if (args.size() != n_req) {
@@ -247,7 +262,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
       for (size_t i = args.size(); i < n_req; ++i) {
         auto it = kwargs.find(arg_names[i]);
         if (it != kwargs.end() && it->first == arg_names[i]) {
-          n->inputs[i] = it->second.outputs[0];
+          n->inputs[i] = it->second->outputs[0];
           ++nmatched;
         } else {
           n->inputs[i] = NodeEntry{
@@ -266,8 +281,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
   } else {
     CHECK_EQ(kwargs.size(), 0) << "Variable length function do not accept kwargs";
     n->inputs.reserve(args.size());
-    for (const Symbol& s : args) {
-      n->inputs.push_back(s.outputs[0]);
+    for (const Symbol* s : args) {
+      n->inputs.push_back(s->outputs[0]);
     }
   }
   UpdateNodeVersion(n);
@@ -283,13 +298,13 @@ void Symbol::Compose(const std::vector<Symbol>& args,
         (const std::shared_ptr<Node> &node) {
       if (node->is_variable()) {
         if (arg_counter < args.size()) {
-          replace_map[node.get()] = &(args[arg_counter].outputs[0]);
+          replace_map[node.get()] = &(args[arg_counter]->outputs[0]);
           ++arg_counter;
         } else {
           // match kwargs
           auto kit = kwargs.find(node->attrs.name);
           if (kit != kwargs.end()) {
-            replace_map[node.get()] = &(kit->second.outputs[0]);
+            replace_map[node.get()] = &(kit->second->outputs[0]);
             ++nmatched;
           }
         }
@@ -334,8 +349,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
   }
 }
-Symbol Symbol::operator () (const std::vector<Symbol>& args,
-                            const std::unordered_map<std::string, Symbol>& kwargs,
+Symbol Symbol::operator () (const array_view<const Symbol*>& args,
+                            const std::unordered_map<std::string, const Symbol*>& kwargs,
                             const std::string& name) const {
   Symbol s = this->Copy();
   s.Compose(args, kwargs, name);
...
/*!
 * Copyright (c) 2016 by Contributors
 * \file order_mutation.cc
 * \brief Add control flow dependencies between nodes
 *  to correctly order mutations and reads, resolving
 *  write-after-read and read-after-write problems.
 */
#include <nnvm/pass.h>
#include <nnvm/op_attr_types.h>

namespace nnvm {

template<typename T>
inline T get_with_default(const std::unordered_map<Node*, T> &map,
                          Node* key,
                          const T& def) {
  auto it = map.find(key);
  if (it != map.end()) return it->second;
  return def;
}

Graph OrderMutation(const Graph& src) {
  std::unordered_map<Node*, std::vector<NodeEntry> > version_hist;
  DFSVisit(src.outputs, [&version_hist](const std::shared_ptr<Node>& n) {
      for (const NodeEntry& e : n->inputs) {
        if (e.node->is_variable()) {
          if (e.version != 0 && version_hist.count(e.node.get()) == 0) {
            version_hist[e.node.get()] = std::vector<NodeEntry>{};
          }
        }
      }
    });
  // no mutation happens, everything is fine.
  if (version_hist.size() == 0) return src;
  // start preparing for remapping the nodes.
  std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
  auto prepare = [&version_hist, &old_new] (const std::shared_ptr<Node>& n) {
    static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
    bool need_repl = false;
    for (size_t i = 0; i < n->inputs.size(); ++i) {
      const NodeEntry& e = n->inputs[i];
      if (e.node->is_variable()) {
        if (e.version != 0) need_repl = true;
        auto it = version_hist.find(e.node.get());
        if (it != version_hist.end()) {
          std::vector<NodeEntry>& vec = it->second;
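          // NOTE: the second field (NodeEntry::index) is repurposed in this
          // history entry to record whether the access mutates the variable
          // (nonzero means a write); the comparator and lower_bound below rely on it.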
          uint32_t is_mutate =
              fmutate_inputs.count(n->op) ? fmutate_inputs[n->op](n->attrs, i) : 0;
          vec.emplace_back(NodeEntry{n, is_mutate, e.version});
        }
      } else {
        if (old_new.count(e.node.get()) != 0) need_repl = true;
      }
    }
    for (const std::shared_ptr<Node>& p : n->control_deps) {
      if (old_new.count(p.get()) != 0) need_repl = true;
    }
    if (need_repl) {
      std::shared_ptr<Node> np = Node::Create();
      np->op = n->op;
      np->attrs = n->attrs;
      old_new[n.get()] = std::move(np);
    }
  };
  DFSVisit(src.outputs, prepare);
  // comparator of history entry
  auto comparator = [](const NodeEntry& a, const NodeEntry &b) {
    if (a.version < b.version) return true;
    if (a.version > b.version) return false;
    return a.index > b.index;
  };
  for (auto &kv : version_hist) {
    std::sort(kv.second.begin(), kv.second.end(), comparator);
  }
  // copy the nodes, as well as add control deps
  for (auto &kv : old_new) {
    // copy the nodes
    for (const NodeEntry& e : kv.first->inputs) {
      auto it = old_new.find(e.node.get());
      if (it != old_new.end()) {
        kv.second->inputs.emplace_back(NodeEntry{it->second, e.index, e.version});
      } else {
        kv.second->inputs.push_back(e);
      }
    }
    for (const std::shared_ptr<Node>& p : kv.first->control_deps) {
      kv.second->control_deps.emplace_back(
          get_with_default(old_new, p.get(), p));
    }
    // add control deps
    static auto& fmutate_inputs = Op::GetAttr<FMutateInput>("FMutateInput");
    for (size_t i = 0; i < kv.first->inputs.size(); ++i) {
      const NodeEntry& e = kv.first->inputs[i];
      if (e.node->is_variable() && version_hist.count(e.node.get()) != 0) {
        FMutateInput fmutate = fmutate_inputs.get(kv.first->op, nullptr);
        uint32_t is_mutate = (fmutate == nullptr) ? 0 : fmutate(kv.first->attrs, i);
        std::vector<NodeEntry>& vec = version_hist.at(e.node.get());
        auto it = std::lower_bound(vec.begin(), vec.end(),
                                   NodeEntry{nullptr, 1, e.version},
                                   comparator);
        if (is_mutate != 0) {
          int read_dep = 0;
          while (it != vec.begin()) {
            --it;
            if (it->index != 0) break;
            ++read_dep;
            // depend on previous read
            kv.second->control_deps.push_back(
                get_with_default(old_new, it->node.get(), it->node));
          }
          if (read_dep == 0 && it->index != 0) {
            // depend on last write
            kv.second->control_deps.push_back(
                get_with_default(old_new, it->node.get(), it->node));
          }
        } else {
          // depend on last write
          if (it->index != 0) {
            kv.second->control_deps.push_back(
                get_with_default(old_new, it->node.get(), it->node));
          }
        }
      }
    }
  }
  Graph ret;
  for (const NodeEntry &e : src.outputs) {
    ret.outputs.emplace_back(NodeEntry{
        get_with_default(old_new, e.node.get(), e.node), e.index, e.version});
  }
  return ret;
}

NNVM_REGISTER_PASS(OrderMutation)
.describe("Return a new graph that adds control dependencies, "\
          "to order the mutation and reads if mutation exists.")
.set_body(OrderMutation)
.set_change_graph(true);

}  // namespace nnvm
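
For orientation, the ordering rule the pass encodes for each mutated Variable is: a read depends on the most recent write, and a write depends on every read issued since the previous write. The standalone sketch below (illustrative only, not nnvm code) applies that rule to the same access pattern the Python test further down builds, and prints exactly the control dependencies that test asserts; the real pass additionally tracks version numbers and sorts the per-variable history so it can derive the same ordering from the graph rather than from program order.

// Minimal, self-contained sketch of the write-after-read / read-after-write rule.
#include <cstdio>
#include <string>
#include <vector>

struct Access {
  std::string op;  // name of the operator touching the variable
  bool is_mutate;  // true if it writes the variable
};

int main() {
  // Access history of one variable x, in program order
  // (mirrors the python test: conv and add1 read x, assign writes x, add2 reads x).
  std::vector<Access> hist = {
      {"conv", false}, {"add1", false}, {"assign", true}, {"add2", false}};
  std::vector<std::string> pending_reads;  // reads since the last write
  std::string last_write;                  // most recent writer, if any
  for (const Access& a : hist) {
    if (a.is_mutate) {
      // write-after-read: the writer must wait for all earlier readers
      for (const std::string& r : pending_reads)
        std::printf("%s control-depends on %s\n", a.op.c_str(), r.c_str());
      pending_reads.clear();
      last_write = a.op;
    } else {
      // read-after-write: the reader must wait for the last writer
      if (!last_write.empty())
        std::printf("%s control-depends on %s\n", a.op.c_str(), last_write.c_str());
      pending_reads.push_back(a.op);
    }
  }
  return 0;
}
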
 /*!
  * Copyright (c) 2016 by Contributors
  * \file saveload_json.cc
- * \brief Passes that defines save and load graph to/from JSON file.
+ * \brief Save and load graph to/from JSON file.
  */
 #include <nnvm/pass.h>
 #include <dmlc/json.h>
...
@@ -2,42 +2,85 @@
 #include <nnvm/op.h>
 #include <nnvm/graph.h>
 #include <nnvm/tuple.h>
+#include <nnvm/c_api.h>
 #include <nnvm/graph_attr_types.h>
+#include <dmlc/timer.h>
 #include <string>

-void test_op() {
-  using namespace nnvm;
-  auto add = Op::Get("add");
-  static auto& nick = Op::GetAttr<std::string>("nick_name");
-  LOG(INFO) << "nick=" << nick[add];
+void test_speed() {
+  auto add = nnvm::Op::Get("add");
+  double tstart = dmlc::GetTime();
+  size_t rep = 1000;
+  size_t n = 1000;
+  std::unordered_map<std::string, const nnvm::Symbol*> tmp;
+  std::vector<const nnvm::Symbol*> vec{2};
+  std::string name = "xx";
+  for (size_t t = 0; t < rep; ++t) {
+    nnvm::Symbol s = nnvm::Symbol::CreateVariable("x");
+    for (size_t i = 0; i < n; ++i) {
+      nnvm::Symbol nw = nnvm::Symbol::CreateFunctor(add, {});
+      vec[0] = &s;
+      vec[1] =&s;
+      tmp.clear();
+      nw.Compose(vec, tmp, name);
+      s = nw;
+    }
+  }
+  double tend = dmlc::GetTime();
+  LOG(INFO) << "compose speed = " << n * rep / (tend - tstart) << " ops/sec";
 }

-void test_tuple() {
-  using nnvm::Tuple;
-  using nnvm::TShape;
-  Tuple<int> x{1, 2, 3};
-  Tuple<int> y{1, 2, 3, 5, 6};
-  x = std::move(y);
-  CHECK_EQ(x.ndim(), 5);
-  Tuple<int> z{1, 2, 3, 5, 6};
-  std::ostringstream os;
-  os << z;
-  CHECK_EQ(os.str(), "(1,2,3,5,6)");
-  std::istringstream is(os.str());
-  is >> y;
-  CHECK_EQ(x, y);
-  Tuple<nnvm::index_t> ss{1, 2, 3};
-  TShape s = ss;
-  s = std::move(ss);
-  CHECK((s == TShape{1, 2, 3}));
+void test_node_speed() {
+  using namespace nnvm;
+  auto add = nnvm::Op::Get("add");
+  double tstart = dmlc::GetTime();
+  size_t rep = 1000;
+  size_t n = 100;
+  for (size_t t = 0; t < rep; ++t) {
+    nnvm::Symbol s = nnvm::Symbol::CreateVariable("x");
+    for (size_t i = 0; i < n; ++i) {
+      auto xx = NodeEntry{Node::Create(), 0, 0};
+      NodeEntry x = s.outputs[0];
+      xx.node->op = add;
+      xx.node->inputs.emplace_back(x);
+      xx.node->inputs.emplace_back(x);
+      Symbol ss;
+      ss.outputs.push_back(xx);
+      s = ss;
+    }
+  }
+  double tend = dmlc::GetTime();
+  LOG(INFO) << "test_node_speed speed = " << n * rep / (tend - tstart) << " ops/sec";
 }

-void test_graph() {
-  nnvm::Symbol s;
+void test_api_speed() {
+  auto add = (void*)nnvm::Op::Get("add");  // NOLINT(*)
+  double tstart = dmlc::GetTime();
+  size_t rep = 1000;
+  size_t n = 1000;
+  std::unordered_map<std::string, const nnvm::Symbol*> tmp;
+  std::vector<const nnvm::Symbol*> vec{2};
+  std::string name = "xx";
+  for (size_t t = 0; t < rep; ++t) {
+    SymbolHandle s;
+    NNSymbolCreateVariable("xx", &s);
+    for (size_t i = 0; i < n; ++i) {
+      SymbolHandle arg[2];
+      SymbolHandle ss;
+      NNSymbolCreateAtomicSymbol(add, 0, nullptr, nullptr, &ss);
+      arg[0] = s;
+      arg[1] = s;
+      NNSymbolCompose(ss, "nn", 2, nullptr, arg);
+      s = ss;
+    }
+  }
+  double tend = dmlc::GetTime();
+  LOG(INFO) << "API compose speed = " << n * rep / (tend - tstart) << " ops/sec";
 }

 int main() {
-  test_tuple();
+  test_speed();
+  test_node_speed();
+  test_api_speed();
   return 0;
 }
+import json
 import nnvm.symbol as sym
 import nnvm.graph as graph
@@ -17,7 +18,24 @@ def test_graph_json_attr():
     g._set_json_attr('ilist', [1,2,3], 'list_int')
     assert g.json_attr('ilist') == [1,2,3]

+def test_order_mutation_pass():
+    x = sym.Variable('x')
+    y = sym.conv2d(data=x, name='conv', dev='gpu')
+    y = sym.add(y, x, name='add1')
+    # write after read
+    z = sym.assign(x, y, name='assign')
+    # read after write
+    t = sym.add(y, x, name='add2')
+    g = graph.create(sym.Group([t, z]))
+    jgraph = json.loads(g.apply(['OrderMutation', 'SaveJSON']).json_attr('json'))
+    jnodes = jgraph['nodes']
+    nindex = {n['name']: i for i, n in enumerate(jnodes)}
+    assert nindex['assign'] in jnodes[nindex['add2']]['control_deps']
+    assert nindex['conv'] in jnodes[nindex['assign']]['control_deps']
+    assert nindex['add1'] in jnodes[nindex['assign']]['control_deps']
+
 if __name__ == "__main__":
+    test_order_mutation_pass()
     test_graph_json_attr()
     test_json_pass()
@@ -36,8 +36,16 @@ def test_mutate_input():
     except NNVMError:
         pass

+def test_copy():
+    x = sym.Variable('x')
+    z = sym.Variable('z')
+    y = sym.exp(sym.add(x, x, name='add', gpu=2),
+                name='exp', gpu=1, attr={"kk": "1"})
+    assert y.__copy__().debug_str() == y.debug_str()
+
 if __name__ == "__main__":
+    test_copy()
     test_default_input()
     test_compose()
     test_mutate_input()