Commit 94ae677a by Tianqi Chen

Updates (#14)

* Remove outstanding cython functions

* Add in operator overload

* Enable JSON to save version
parent c1e48e1a
...@@ -14,10 +14,6 @@ include "./base.pyi" ...@@ -14,10 +14,6 @@ include "./base.pyi"
cdef extern from "nnvm/c_api.h": cdef extern from "nnvm/c_api.h":
const char* NNGetLastError(); const char* NNGetLastError();
int NNSymbolCreateVariable(const char *name, SymbolHandle *out);
int NNSymbolCreateGroup(nn_uint num_symbols,
SymbolHandle *symbols,
SymbolHandle *out);
int NNSymbolListAtomicSymbolCreators(nn_uint *out_size, int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
AtomicSymbolCreator **out_array); AtomicSymbolCreator **out_array);
int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator, int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
...@@ -34,31 +30,10 @@ cdef extern from "nnvm/c_api.h": ...@@ -34,31 +30,10 @@ cdef extern from "nnvm/c_api.h":
const char ***arg_descriptions, const char ***arg_descriptions,
const char **return_type); const char **return_type);
int NNSymbolFree(SymbolHandle symbol); int NNSymbolFree(SymbolHandle symbol);
int NNSymbolPrint(SymbolHandle symbol, const char **out_str);
int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
int NNSymbolGetAttr(SymbolHandle symbol,
const char* key,
const char** out,
int *success);
int NNSymbolSetAttrs(SymbolHandle symbol, int NNSymbolSetAttrs(SymbolHandle symbol,
nn_uint num_param, nn_uint num_param,
const char** keys, const char** keys,
const char** values); const char** values);
int NNSymbolListAttrs(SymbolHandle symbol,
int recursive_option,
nn_uint *out_size,
const char*** out);
int NNSymbolListArguments(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
int NNSymbolListOutputs(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
int NNSymbolGetInternals(SymbolHandle symbol,
SymbolHandle *out);
int NNSymbolGetOutput(SymbolHandle symbol,
nn_uint index,
SymbolHandle *out);
int NNSymbolCompose(SymbolHandle sym, int NNSymbolCompose(SymbolHandle sym,
const char* name, const char* name,
nn_uint num_args, nn_uint num_args,
......
...@@ -30,12 +30,71 @@ class Symbol(SymbolBase): ...@@ -30,12 +30,71 @@ class Symbol(SymbolBase):
def __add__(self, other): def __add__(self, other):
if isinstance(other, Symbol): if isinstance(other, Symbol):
return _internal.__add__symbol__(self, other) return _internal.__add_symbol__(self, other)
elif isinstance(other, _Number): elif isinstance(other, _Number):
return _internal.__add__scalar__(self, scalar=other) return _internal.__add_scalar__(self, scalar=other)
else: else:
raise TypeError("type %s not supported" % str(type(other))) raise TypeError("type %s not supported" % str(type(other)))
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
if isinstance(other, Symbol):
return _internal.__sub_symbol__(self, other)
if isinstance(other, Number):
return _internal.__sub_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __rsub__(self, other):
if isinstance(other, Number):
return _internal.__rsub_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __mul__(self, other):
if isinstance(other, Symbol):
return _internal.__mul_symbol__(self, other)
if isinstance(other, Number):
return _internal.__mul_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __rmul__(self, other):
return self.__mul__(other)
def __div__(self, other):
if isinstance(other, Symbol):
return _internal.__div_symbol__(self, other)
if isinstance(other, Number):
return _internal.__div_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __rdiv__(self, other):
if isinstance(other, Number):
return _internal.__rdiv_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __truediv__(self, other):
return self.__div__(other)
def __rtruediv__(self, other):
return self.__rdiv__(other)
def __pow__(self, other):
if isinstance(other, Symbol):
return _internal.__pow_symbol__(self, other)
if isinstance(other, Number):
return _internal.__pow_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __neg__(self):
return self.__mul__(-1.0)
def __copy__(self): def __copy__(self):
return self.__deepcopy__() return self.__deepcopy__()
......
...@@ -11,9 +11,11 @@ using nnvm::NodeAttrs; ...@@ -11,9 +11,11 @@ using nnvm::NodeAttrs;
NNVM_REGISTER_OP(add) NNVM_REGISTER_OP(add)
.describe("add two data together") .describe("add two data together")
.set_num_inputs(2) .set_num_inputs(2);
.attr("inplace_pair", std::make_pair(0, 0));
NNVM_REGISTER_OP(__add_symbol__)
.describe("Alias of add")
.set_num_inputs(2);
NNVM_REGISTER_OP(exp) NNVM_REGISTER_OP(exp)
.describe("take exponmential") .describe("take exponmential")
......
...@@ -30,7 +30,32 @@ namespace pass { ...@@ -30,7 +30,32 @@ namespace pass {
// auxiliary node structure for serialization. // auxiliary node structure for serialization.
struct JSONNode { struct JSONNode {
// the node entry structure in serialized format // the node entry structure in serialized format
typedef std::pair<uint32_t, uint32_t> Entry; struct Entry {
uint32_t node_id;
uint32_t index;
uint32_t version;
void Save(dmlc::JSONWriter *writer) const {
writer->BeginArray();
writer->WriteArrayItem(node_id);
writer->WriteArrayItem(index);
writer->WriteArrayItem(version);
writer->EndArray();
}
void Load(dmlc::JSONReader *reader) {
reader->BeginArray();
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&node_id);
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&index);
if (reader->NextArrayItem()) {
reader->Read(&version);
CHECK(!reader->NextArrayItem()) << "invalid json format";
} else {
version = 0;
}
}
};
// pointer to the graph node // pointer to the graph node
NodePtr node; NodePtr node;
// inputs // inputs
...@@ -75,6 +100,10 @@ struct JSONNode { ...@@ -75,6 +100,10 @@ struct JSONNode {
if (op_type_str != "null") { if (op_type_str != "null") {
try { try {
node->op = Op::Get(op_type_str); node->op = Op::Get(op_type_str);
// rebuild attribute parser
if (node->op->attr_parser != nullptr) {
node->op->attr_parser(&(node->attrs));
}
} catch (const dmlc::Error &err) { } catch (const dmlc::Error &err) {
std::ostringstream os; std::ostringstream os;
os << "Failed loading Op " << node->attrs.name os << "Failed loading Op " << node->attrs.name
...@@ -132,7 +161,7 @@ Graph LoadJSON(const Graph& src) { ...@@ -132,7 +161,7 @@ Graph LoadJSON(const Graph& src) {
n.node->inputs.reserve(n.inputs.size()); n.node->inputs.reserve(n.inputs.size());
for (const JSONNode::Entry &e : n.inputs) { for (const JSONNode::Entry &e : n.inputs) {
n.node->inputs.emplace_back( n.node->inputs.emplace_back(
NodeEntry{jgraph.nodes[e.first].node, e.second}); NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
} }
n.node->control_deps.reserve(n.control_deps.size()); n.node->control_deps.reserve(n.control_deps.size());
for (uint32_t nid : n.control_deps) { for (uint32_t nid : n.control_deps) {
...@@ -150,7 +179,7 @@ Graph LoadJSON(const Graph& src) { ...@@ -150,7 +179,7 @@ Graph LoadJSON(const Graph& src) {
ret.outputs.reserve(jgraph.heads.size()); ret.outputs.reserve(jgraph.heads.size());
for (const JSONNode::Entry &e : jgraph.heads) { for (const JSONNode::Entry &e : jgraph.heads) {
ret.outputs.emplace_back( ret.outputs.emplace_back(
NodeEntry{jgraph.nodes[e.first].node, e.second}); NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
} }
return ret; return ret;
} }
...@@ -170,7 +199,7 @@ Graph SaveJSON(const Graph& src) { ...@@ -170,7 +199,7 @@ Graph SaveJSON(const Graph& src) {
jnode.inputs.reserve(n->inputs.size()); jnode.inputs.reserve(n->inputs.size());
for (const NodeEntry& e : n->inputs) { for (const NodeEntry& e : n->inputs) {
jnode.inputs.emplace_back( jnode.inputs.emplace_back(
std::make_pair(node2index.at(e.node.get()), e.index)); JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
} }
for (const NodePtr& c : n->control_deps) { for (const NodePtr& c : n->control_deps) {
jnode.control_deps.push_back(node2index.at(c.get())); jnode.control_deps.push_back(node2index.at(c.get()));
...@@ -179,7 +208,8 @@ Graph SaveJSON(const Graph& src) { ...@@ -179,7 +208,8 @@ Graph SaveJSON(const Graph& src) {
}); });
for (const NodeEntry& e : src.outputs) { for (const NodeEntry& e : src.outputs) {
jgraph.heads.push_back(std::make_pair(node2index.at(e.node.get()), e.index)); jgraph.heads.push_back(
JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
} }
std::ostringstream os; std::ostringstream os;
......
...@@ -33,7 +33,7 @@ def test_order_mutation_pass(): ...@@ -33,7 +33,7 @@ def test_order_mutation_pass():
assert nindex['assign'] in jnodes[nindex['add2']]['control_deps'] assert nindex['assign'] in jnodes[nindex['add2']]['control_deps']
assert nindex['conv'] in jnodes[nindex['assign']]['control_deps'] assert nindex['conv'] in jnodes[nindex['assign']]['control_deps']
assert nindex['add1'] in jnodes[nindex['assign']]['control_deps'] assert nindex['add1'] in jnodes[nindex['assign']]['control_deps']
assert jnodes[nindex['assign']]['inputs'][0][2] == 1
if __name__ == "__main__": if __name__ == "__main__":
test_order_mutation_pass() test_order_mutation_pass()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment