Commit 94ae677a by Tianqi Chen

Updates (#14)

* Remove outstanding cython functions

* Add in operator overload

* Enable JSON to save version
parent c1e48e1a
......@@ -14,10 +14,6 @@ include "./base.pyi"
cdef extern from "nnvm/c_api.h":
const char* NNGetLastError();
int NNSymbolCreateVariable(const char *name, SymbolHandle *out);
int NNSymbolCreateGroup(nn_uint num_symbols,
SymbolHandle *symbols,
SymbolHandle *out);
int NNSymbolListAtomicSymbolCreators(nn_uint *out_size,
AtomicSymbolCreator **out_array);
int NNSymbolCreateAtomicSymbol(AtomicSymbolCreator creator,
......@@ -34,31 +30,10 @@ cdef extern from "nnvm/c_api.h":
const char ***arg_descriptions,
const char **return_type);
int NNSymbolFree(SymbolHandle symbol);
int NNSymbolPrint(SymbolHandle symbol, const char **out_str);
int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
int NNSymbolGetAttr(SymbolHandle symbol,
const char* key,
const char** out,
int *success);
int NNSymbolSetAttrs(SymbolHandle symbol,
nn_uint num_param,
const char** keys,
const char** values);
int NNSymbolListAttrs(SymbolHandle symbol,
int recursive_option,
nn_uint *out_size,
const char*** out);
int NNSymbolListArguments(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
int NNSymbolListOutputs(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
int NNSymbolGetInternals(SymbolHandle symbol,
SymbolHandle *out);
int NNSymbolGetOutput(SymbolHandle symbol,
nn_uint index,
SymbolHandle *out);
int NNSymbolCompose(SymbolHandle sym,
const char* name,
nn_uint num_args,
......
......@@ -30,12 +30,71 @@ class Symbol(SymbolBase):
def __add__(self, other):
if isinstance(other, Symbol):
return _internal.__add__symbol__(self, other)
return _internal.__add_symbol__(self, other)
elif isinstance(other, _Number):
return _internal.__add__scalar__(self, scalar=other)
return _internal.__add_scalar__(self, scalar=other)
else:
raise TypeError("type %s not supported" % str(type(other)))
def __radd__(self, other):
return self.__add__(other)
def __sub__(self, other):
if isinstance(other, Symbol):
return _internal.__sub_symbol__(self, other)
if isinstance(other, Number):
return _internal.__sub_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __rsub__(self, other):
if isinstance(other, Number):
return _internal.__rsub_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __mul__(self, other):
if isinstance(other, Symbol):
return _internal.__mul_symbol__(self, other)
if isinstance(other, Number):
return _internal.__mul_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __rmul__(self, other):
return self.__mul__(other)
def __div__(self, other):
if isinstance(other, Symbol):
return _internal.__div_symbol__(self, other)
if isinstance(other, Number):
return _internal.__div_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __rdiv__(self, other):
if isinstance(other, Number):
return _internal.__rdiv_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __truediv__(self, other):
return self.__div__(other)
def __rtruediv__(self, other):
return self.__rdiv__(other)
def __pow__(self, other):
if isinstance(other, Symbol):
return _internal.__pow_symbol__(self, other)
if isinstance(other, Number):
return _internal.__pow_scalar__(self, scalar=other)
else:
raise TypeError('type %s not supported' % str(type(other)))
def __neg__(self):
return self.__mul__(-1.0)
def __copy__(self):
return self.__deepcopy__()
......
......@@ -11,9 +11,11 @@ using nnvm::NodeAttrs;
NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.attr("inplace_pair", std::make_pair(0, 0));
.set_num_inputs(2);
NNVM_REGISTER_OP(__add_symbol__)
.describe("Alias of add")
.set_num_inputs(2);
NNVM_REGISTER_OP(exp)
.describe("take exponmential")
......
......@@ -30,7 +30,32 @@ namespace pass {
// auxiliary node structure for serialization.
struct JSONNode {
// the node entry structure in serialized format
typedef std::pair<uint32_t, uint32_t> Entry;
struct Entry {
uint32_t node_id;
uint32_t index;
uint32_t version;
void Save(dmlc::JSONWriter *writer) const {
writer->BeginArray();
writer->WriteArrayItem(node_id);
writer->WriteArrayItem(index);
writer->WriteArrayItem(version);
writer->EndArray();
}
void Load(dmlc::JSONReader *reader) {
reader->BeginArray();
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&node_id);
CHECK(reader->NextArrayItem()) << "invalid json format";
reader->Read(&index);
if (reader->NextArrayItem()) {
reader->Read(&version);
CHECK(!reader->NextArrayItem()) << "invalid json format";
} else {
version = 0;
}
}
};
// pointer to the graph node
NodePtr node;
// inputs
......@@ -75,6 +100,10 @@ struct JSONNode {
if (op_type_str != "null") {
try {
node->op = Op::Get(op_type_str);
// rebuild attribute parser
if (node->op->attr_parser != nullptr) {
node->op->attr_parser(&(node->attrs));
}
} catch (const dmlc::Error &err) {
std::ostringstream os;
os << "Failed loading Op " << node->attrs.name
......@@ -132,7 +161,7 @@ Graph LoadJSON(const Graph& src) {
n.node->inputs.reserve(n.inputs.size());
for (const JSONNode::Entry &e : n.inputs) {
n.node->inputs.emplace_back(
NodeEntry{jgraph.nodes[e.first].node, e.second});
NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
}
n.node->control_deps.reserve(n.control_deps.size());
for (uint32_t nid : n.control_deps) {
......@@ -150,7 +179,7 @@ Graph LoadJSON(const Graph& src) {
ret.outputs.reserve(jgraph.heads.size());
for (const JSONNode::Entry &e : jgraph.heads) {
ret.outputs.emplace_back(
NodeEntry{jgraph.nodes[e.first].node, e.second});
NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
}
return ret;
}
......@@ -170,7 +199,7 @@ Graph SaveJSON(const Graph& src) {
jnode.inputs.reserve(n->inputs.size());
for (const NodeEntry& e : n->inputs) {
jnode.inputs.emplace_back(
std::make_pair(node2index.at(e.node.get()), e.index));
JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
}
for (const NodePtr& c : n->control_deps) {
jnode.control_deps.push_back(node2index.at(c.get()));
......@@ -179,7 +208,8 @@ Graph SaveJSON(const Graph& src) {
});
for (const NodeEntry& e : src.outputs) {
jgraph.heads.push_back(std::make_pair(node2index.at(e.node.get()), e.index));
jgraph.heads.push_back(
JSONNode::Entry{node2index.at(e.node.get()), e.index, e.version});
}
std::ostringstream os;
......
......@@ -33,7 +33,7 @@ def test_order_mutation_pass():
assert nindex['assign'] in jnodes[nindex['add2']]['control_deps']
assert nindex['conv'] in jnodes[nindex['assign']]['control_deps']
assert nindex['add1'] in jnodes[nindex['assign']]['control_deps']
assert jnodes[nindex['assign']]['inputs'][0][2] == 1
if __name__ == "__main__":
test_order_mutation_pass()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment