saveload_json.cc 9.75 KB
Newer Older
1 2 3
/*!
 *  Copyright (c) 2016 by Contributors
 * \file saveload_json.cc
 * \brief Save and load graph to/from JSON file.
 */
#include <nnvm/pass.h>
7
#include <nnvm/pass_functions.h>
8 9 10 11 12 13 14
#include <dmlc/json.h>
#include <algorithm>

namespace dmlc {
namespace json {
// overload handler for shared ptr
template<>
15 16
struct Handler<std::shared_ptr<any> > {
  inline static void Write(JSONWriter *writer, const std::shared_ptr<any> &data) {
17 18
    writer->Write(*data);
  }
19
  inline static void Read(JSONReader *reader, std::shared_ptr<any> *data) {
20 21 22 23 24 25 26 27 28 29
    any v;
    reader->Read(&v);
    *data = std::make_shared<any>(std::move(v));
  }
};
}  // namespace json
}  // namespace dmlc

namespace nnvm {
namespace pass {
30
namespace {
31

32 33 34 35 36
// JSONNode represents an nnvm::Node in JSON.
// It is declared up front together with JSONGraph because the two refer to each
// other: JSONNode stores a std::vector<JSONGraph> (its subgraphs) and JSONGraph
// stores a std::vector<JSONNode> (its nodes).
struct JSONNode;
// JSONGraph represents an nnvm::Graph or nnvm::Symbol in JSON
struct JSONGraph;

37 38 39
// auxiliary node structure for serialization.
struct JSONNode {
  // the node entry structure in serialized format
Tianqi Chen committed
40 41 42 43
  struct Entry {
    uint32_t node_id;
    uint32_t index;
    uint32_t version;
44 45 46 47
    Entry() = default;
    Entry(uint32_t node_id, uint32_t index, uint32_t version):
      node_id(node_id), index(index), version(version) {
    }
Tianqi Chen committed
48
    void Save(dmlc::JSONWriter *writer) const {
49
      writer->BeginArray(false);
Tianqi Chen committed
50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
      writer->WriteArrayItem(node_id);
      writer->WriteArrayItem(index);
      writer->WriteArrayItem(version);
      writer->EndArray();
    }
    void Load(dmlc::JSONReader *reader) {
      reader->BeginArray();
      CHECK(reader->NextArrayItem()) << "invalid json format";
      reader->Read(&node_id);
      CHECK(reader->NextArrayItem()) << "invalid json format";
      reader->Read(&index);
      if (reader->NextArrayItem()) {
        reader->Read(&version);
        CHECK(!reader->NextArrayItem()) << "invalid json format";
      } else {
        version = 0;
      }
    }
  };

70
  // pointer to the graph node
71
  NodePtr node;
72 73 74 75
  // inputs
  std::vector<Entry> inputs;
  // control flow dependencies
  std::vector<uint32_t> control_deps;
76 77
  // subgraphs
  std::vector<JSONGraph> subgraphs;
78 79 80 81

  // function to save JSON node.
  void Save(dmlc::JSONWriter *writer) const {
    writer->BeginObject();
82 83
    if (node->op() != nullptr) {
      writer->WriteObjectKeyValue("op", node->op()->name);
84 85 86 87 88
    } else {
      std::string json_null = "null";
      writer->WriteObjectKeyValue("op", json_null);
    }
    writer->WriteObjectKeyValue("name", node->attrs.name);
89
    if (node->attrs.dict.size() != 0) {
90 91 92
      // write attributes in order;
      std::map<std::string, std::string> dict(
          node->attrs.dict.begin(), node->attrs.dict.end());
93
      writer->WriteObjectKeyValue("attrs", dict);
94
    }
95
    writer->WriteObjectKeyValue("inputs", inputs);
96 97 98
    if (control_deps.size() != 0) {
      writer->WriteObjectKeyValue("control_deps", control_deps);
    }
99 100 101
    if (subgraphs.size() != 0) {
      writer->WriteObjectKeyValue("subgraphs", subgraphs);
    }
102 103 104 105
    writer->EndObject();
  }

  void Load(dmlc::JSONReader *reader) {
Tianqi Chen committed
106
    node = Node::Create();
107 108 109 110 111 112
    control_deps.clear();
    dmlc::JSONObjectReadHelper helper;
    std::string op_type_str;
    helper.DeclareField("op", &op_type_str);
    helper.DeclareField("name", &(node->attrs.name));
    helper.DeclareField("inputs", &inputs);
113
    helper.DeclareOptionalField("attrs", &(node->attrs.dict));
114 115
    helper.DeclareOptionalField("attr", &(node->attrs.dict));
    helper.DeclareOptionalField("control_deps", &control_deps);
116
    helper.DeclareOptionalField("subgraphs", &subgraphs);
117 118 119 120 121 122
    // backward compatible code with mxnet graph.
    int backward_source_id;
    std::unordered_map<std::string, std::string> param;
    helper.DeclareOptionalField("param", &param);
    helper.DeclareOptionalField("backward_source_id", &backward_source_id);
    helper.ReadAllFields(reader);
123
    node->attrs.dict.insert(param.begin(), param.end());
124 125 126

    if (op_type_str != "null") {
      try {
127
        node->attrs.op = Op::Get(op_type_str);
128 129 130 131 132 133 134
      } catch (const dmlc::Error &err) {
        std::ostringstream os;
        os << "Failed loading Op " << node->attrs.name
           << " of type " << op_type_str << ": " << err.what();
        throw dmlc::Error(os.str());
      }
    } else {
135
      node->attrs.op = nullptr;
136 137 138 139 140 141 142 143
    }
  }
};

// graph structure to help read/save JSON.
struct JSONGraph {
  std::vector<JSONNode> nodes;
  std::vector<uint32_t> arg_nodes;
144
  std::vector<uint32_t> node_row_ptr;
145
  std::vector<JSONNode::Entry> heads;
146
  std::unordered_map<std::string, std::shared_ptr<any> > attrs;
147 148 149 150 151

  void Save(dmlc::JSONWriter *writer) const {
    writer->BeginObject();
    writer->WriteObjectKeyValue("nodes", nodes);
    writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
152
    writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr);
153 154 155 156 157 158 159 160 161 162 163 164 165
    writer->WriteObjectKeyValue("heads", heads);
    if (attrs.size() != 0) {
      writer->WriteObjectKeyValue("attrs", attrs);
    }
    writer->EndObject();
  }

  void Load(dmlc::JSONReader *reader) {
    attrs.clear();
    dmlc::JSONObjectReadHelper helper;
    helper.DeclareField("nodes", &nodes);
    helper.DeclareField("arg_nodes", &arg_nodes);
    helper.DeclareField("heads", &heads);
166
    helper.DeclareOptionalField("node_row_ptr", &node_row_ptr);
167 168 169 170 171
    helper.DeclareOptionalField("attrs", &attrs);
    helper.ReadAllFields(reader);
  }
};

172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph *jgraph) {
  std::unordered_map<Node*, uint32_t> node2index;
  jgraph->node_row_ptr.push_back(0);
  DFSVisit(src->outputs, [&node2index, jgraph](const NodePtr& n) {
    uint32_t nid = static_cast<uint32_t>(jgraph->nodes.size());
    node2index[n.get()] = nid;
    if (n->is_variable()) {
      jgraph->arg_nodes.push_back(nid);
    }
    JSONNode jnode;
    jnode.node = n;
    jnode.inputs.reserve(n->inputs.size());
    for (const NodeEntry& e : n->inputs) {
      jnode.inputs.emplace_back(node2index.at(e.node.get()), e.index, e.version);
    }
    for (const NodePtr& c : n->control_deps) {
      jnode.control_deps.push_back(node2index.at(c.get()));
    }
    jgraph->node_row_ptr.push_back(jgraph->node_row_ptr.back() + n->num_outputs());
    jgraph->nodes.emplace_back(std::move(jnode));
  });
  for (const NodeEntry& e : src->outputs) {
    jgraph->heads.emplace_back(node2index.at(e.node.get()), e.index, e.version);
195
  }
196 197 198 199 200 201 202 203 204 205 206 207 208 209
  // recursively construct subgraphs
  for (JSONNode &jnode : jgraph->nodes) {
    // construct jnode's subgraphs
    const std::vector<std::shared_ptr<Symbol>> &subgraphs = jnode.node->attrs.subgraphs;
    std::vector<JSONGraph> &jsubgraphs = jnode.subgraphs;
    jsubgraphs.resize(subgraphs.size());
    for (uint32_t i = 0; i < subgraphs.size(); ++i) {
      Symbol2JSONGraph(subgraphs[i], &jsubgraphs[i]);
    }
  }
}

std::shared_ptr<Symbol> JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) {
  for (const JSONNode &n : jgraph.nodes) {
210 211
    n.node->inputs.reserve(n.inputs.size());
    for (const JSONNode::Entry &e : n.inputs) {
212
      n.node->inputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
213 214 215 216 217
    }
    n.node->control_deps.reserve(n.control_deps.size());
    for (uint32_t nid : n.control_deps) {
      n.node->control_deps.push_back(jgraph.nodes[nid].node);
    }
218
    // rebuild attribute parser
219
    if (!no_parse && n.node->op() != nullptr && n.node->op()->attr_parser != nullptr) {
220 221
      n.node->op()->attr_parser(&(n.node->attrs));
    }
222 223 224 225 226 227 228 229
    for (const JSONGraph &subgraph : n.subgraphs) {
      // The "no_parse" option here, is to be compatible with
      // commit cfd3075e85807dcd8f9534c37e053583dee87524
      // (https://github.com/apache/incubator-mxnet/tree/cfd3075e85807dcd8f9534c37e053583dee87524),
      // where the parsing of main graph is deferred until
      // incubator-mxnet/src/nnvm/legacy_json_util.cc:UpgradeJSON_Parse
      n.node->attrs.subgraphs.push_back(JSONGraph2Symbol(subgraph, false));
    }
230
  }
231
  // consistency check
232 233 234
  for (uint32_t nid : jgraph.arg_nodes) {
    CHECK(jgraph.nodes[nid].node->is_variable());
  }
235 236 237 238 239 240 241
  std::shared_ptr<Symbol> symbol = std::make_shared<Symbol>();
  symbol->outputs.reserve(jgraph.heads.size());
  for (const JSONNode::Entry &e : jgraph.heads) {
    symbol->outputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
  }
  return symbol;
}
242

243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258
// Pass body: build a Graph from the JSON text stored in src.attrs["json"].
// If the bool attribute "load_json_no_parse" is present, its value is
// forwarded to JSONGraph2Symbol as the no_parse flag.
// The returned graph takes over the attributes read from the JSON and the
// outputs of the reconstructed symbol.
Graph LoadJSON(Graph src) {
  CHECK_NE(src.attrs.count("json"), 0U)
      << "Load JSON require json to be presented.";
  const std::string &json_str =
      nnvm::get<std::string>(*src.attrs.at("json"));

  const auto flag = src.attrs.find("load_json_no_parse");
  const bool no_parse =
      (flag != src.attrs.end()) && nnvm::get<bool>(*flag->second);

  // parse the text into the intermediate JSON graph
  JSONGraph jgraph;
  {
    std::istringstream is(json_str);
    dmlc::JSONReader reader(&is);
    jgraph.Load(&reader);
  }
  const std::shared_ptr<Symbol> symbol = JSONGraph2Symbol(jgraph, no_parse);

  // assemble the resulting graph
  Graph ret;
  ret.attrs = std::move(jgraph.attrs);
  ret.outputs = symbol->outputs;
  return ret;
}

// save a graph to json
267
Graph SaveJSON(Graph src) {
268 269
  std::shared_ptr<Symbol> src_symbol = std::make_shared<Symbol>();
  src_symbol->outputs = src.outputs;
270
  JSONGraph jgraph;
271
  Symbol2JSONGraph(src_symbol, &jgraph);
272
  jgraph.attrs = src.attrs;
273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293
  std::ostringstream os;
  dmlc::JSONWriter writer(&os);
  jgraph.Save(&writer);
  Graph ret;
  ret.attrs["json"] = std::make_shared<any>(os.str());
  return ret;
}

// Register the two passes defined in this file.
// LoadJSON is bound to the LoadJSON function above, is flagged as changing the
// graph, and declares a dependency on the graph attribute "json".
NNVM_REGISTER_PASS(LoadJSON)
.describe("Return a new Graph, loaded from src.attrs[\"json\"]")
.set_body(LoadJSON)
.set_change_graph(true)
.depend_graph_attr("json");

// SaveJSON is bound to the SaveJSON function above, is flagged as changing the
// graph, and declares that it provides the graph attribute "json".
NNVM_REGISTER_PASS(SaveJSON)
.describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]")
.set_body(SaveJSON)
.set_change_graph(true)
.provide_graph_attr("json");

294 295 296

// Register value types for the JSON handling of `any`, each with a type key.
// NOTE(review): presumably these are the types a graph attribute may hold when
// it is written or read through the shared_ptr<any> handler in this file -
// confirm against dmlc/json.h.
DMLC_JSON_ENABLE_ANY(std::string, str);
DMLC_JSON_ENABLE_ANY(std::vector<int>, list_int);
297
// Register std::vector<std::string> for the JSON handling of `any` with the
// key list_str. NOTE(review): exact macro semantics live in dmlc/json.h - confirm.
DMLC_JSON_ENABLE_ANY(std::vector<std::string>, list_str);
298

299
}  // namespace
300 301
}  // namespace pass
}  // namespace nnvm