saveload_json.cc 10.8 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

20 21
/*!
 * \file saveload_json.cc
 * \brief Save and load graph to/from JSON file.
 */
#include <nnvm/pass.h>
25
#include <nnvm/pass_functions.h>
26 27 28 29 30 31 32
#include <dmlc/json.h>
#include <algorithm>

namespace dmlc {
namespace json {
// overload handler for shared ptr
template<>
33 34
struct Handler<std::shared_ptr<any> > {
  inline static void Write(JSONWriter *writer, const std::shared_ptr<any> &data) {
35 36
    writer->Write(*data);
  }
37
  inline static void Read(JSONReader *reader, std::shared_ptr<any> *data) {
38 39 40 41 42 43 44 45 46 47
    any v;
    reader->Read(&v);
    *data = std::make_shared<any>(std::move(v));
  }
};
}  // namespace json
}  // namespace dmlc

namespace nnvm {
namespace pass {
48
namespace {
49

50 51 52 53 54
// JSONNode represents an nnvm::Node in JSON
struct JSONNode;
// JSONGraph represents an nnvm::Graph or nnvm::Symbol in JSON
struct JSONGraph;

55 56 57
// auxiliary node structure for serialization.
struct JSONNode {
  // the node entry structure in serialized format
Tianqi Chen committed
58 59 60 61
  struct Entry {
    uint32_t node_id;
    uint32_t index;
    uint32_t version;
62 63 64 65
    Entry() = default;
    Entry(uint32_t node_id, uint32_t index, uint32_t version):
      node_id(node_id), index(index), version(version) {
    }
Tianqi Chen committed
66
    void Save(dmlc::JSONWriter *writer) const {
67
      writer->BeginArray(false);
Tianqi Chen committed
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
      writer->WriteArrayItem(node_id);
      writer->WriteArrayItem(index);
      writer->WriteArrayItem(version);
      writer->EndArray();
    }
    void Load(dmlc::JSONReader *reader) {
      reader->BeginArray();
      CHECK(reader->NextArrayItem()) << "invalid json format";
      reader->Read(&node_id);
      CHECK(reader->NextArrayItem()) << "invalid json format";
      reader->Read(&index);
      if (reader->NextArrayItem()) {
        reader->Read(&version);
        CHECK(!reader->NextArrayItem()) << "invalid json format";
      } else {
        version = 0;
      }
    }
  };

88
  // pointer to the graph node
89
  ObjectPtr node;
90 91 92 93
  // inputs
  std::vector<Entry> inputs;
  // control flow dependencies
  std::vector<uint32_t> control_deps;
94 95
  // subgraphs
  std::vector<JSONGraph> subgraphs;
96 97 98 99

  // function to save JSON node.
  void Save(dmlc::JSONWriter *writer) const {
    writer->BeginObject();
100 101
    if (node->op() != nullptr) {
      writer->WriteObjectKeyValue("op", node->op()->name);
102 103 104 105 106
    } else {
      std::string json_null = "null";
      writer->WriteObjectKeyValue("op", json_null);
    }
    writer->WriteObjectKeyValue("name", node->attrs.name);
107
    if (node->attrs.dict.size() != 0) {
108 109 110
      // write attributes in order;
      std::map<std::string, std::string> dict(
          node->attrs.dict.begin(), node->attrs.dict.end());
111
      writer->WriteObjectKeyValue("attrs", dict);
112
    }
113
    writer->WriteObjectKeyValue("inputs", inputs);
114 115 116
    if (control_deps.size() != 0) {
      writer->WriteObjectKeyValue("control_deps", control_deps);
    }
117 118 119
    if (subgraphs.size() != 0) {
      writer->WriteObjectKeyValue("subgraphs", subgraphs);
    }
120 121 122 123
    writer->EndObject();
  }

  void Load(dmlc::JSONReader *reader) {
Tianqi Chen committed
124
    node = Node::Create();
125 126 127 128 129 130
    control_deps.clear();
    dmlc::JSONObjectReadHelper helper;
    std::string op_type_str;
    helper.DeclareField("op", &op_type_str);
    helper.DeclareField("name", &(node->attrs.name));
    helper.DeclareField("inputs", &inputs);
131
    helper.DeclareOptionalField("attrs", &(node->attrs.dict));
132 133
    helper.DeclareOptionalField("attr", &(node->attrs.dict));
    helper.DeclareOptionalField("control_deps", &control_deps);
134
    helper.DeclareOptionalField("subgraphs", &subgraphs);
135 136 137 138 139 140
    // backward compatible code with mxnet graph.
    int backward_source_id;
    std::unordered_map<std::string, std::string> param;
    helper.DeclareOptionalField("param", &param);
    helper.DeclareOptionalField("backward_source_id", &backward_source_id);
    helper.ReadAllFields(reader);
141
    node->attrs.dict.insert(param.begin(), param.end());
142 143 144

    if (op_type_str != "null") {
      try {
145
        node->attrs.op = Op::Get(op_type_str);
146 147 148 149 150 151 152
      } catch (const dmlc::Error &err) {
        std::ostringstream os;
        os << "Failed loading Op " << node->attrs.name
           << " of type " << op_type_str << ": " << err.what();
        throw dmlc::Error(os.str());
      }
    } else {
153
      node->attrs.op = nullptr;
154 155 156 157 158 159 160 161
    }
  }
};

// graph structure to help read/save JSON.
struct JSONGraph {
  std::vector<JSONNode> nodes;
  std::vector<uint32_t> arg_nodes;
162
  std::vector<uint32_t> node_row_ptr;
163
  std::vector<JSONNode::Entry> heads;
164
  std::unordered_map<std::string, std::shared_ptr<any> > attrs;
165 166 167 168 169

  void Save(dmlc::JSONWriter *writer) const {
    writer->BeginObject();
    writer->WriteObjectKeyValue("nodes", nodes);
    writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
170
    writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr);
171 172 173 174 175 176 177 178 179 180 181 182 183
    writer->WriteObjectKeyValue("heads", heads);
    if (attrs.size() != 0) {
      writer->WriteObjectKeyValue("attrs", attrs);
    }
    writer->EndObject();
  }

  void Load(dmlc::JSONReader *reader) {
    attrs.clear();
    dmlc::JSONObjectReadHelper helper;
    helper.DeclareField("nodes", &nodes);
    helper.DeclareField("arg_nodes", &arg_nodes);
    helper.DeclareField("heads", &heads);
184
    helper.DeclareOptionalField("node_row_ptr", &node_row_ptr);
185 186 187 188 189
    helper.DeclareOptionalField("attrs", &attrs);
    helper.ReadAllFields(reader);
  }
};

190 191 192
void Symbol2JSONGraph(std::shared_ptr<Symbol> src, JSONGraph *jgraph) {
  std::unordered_map<Node*, uint32_t> node2index;
  jgraph->node_row_ptr.push_back(0);
193
  DFSVisit(src->outputs, [&node2index, jgraph](const ObjectPtr& n) {
194 195 196 197 198 199 200 201 202 203 204
    uint32_t nid = static_cast<uint32_t>(jgraph->nodes.size());
    node2index[n.get()] = nid;
    if (n->is_variable()) {
      jgraph->arg_nodes.push_back(nid);
    }
    JSONNode jnode;
    jnode.node = n;
    jnode.inputs.reserve(n->inputs.size());
    for (const NodeEntry& e : n->inputs) {
      jnode.inputs.emplace_back(node2index.at(e.node.get()), e.index, e.version);
    }
205
    for (const ObjectPtr& c : n->control_deps) {
206 207 208 209 210 211 212
      jnode.control_deps.push_back(node2index.at(c.get()));
    }
    jgraph->node_row_ptr.push_back(jgraph->node_row_ptr.back() + n->num_outputs());
    jgraph->nodes.emplace_back(std::move(jnode));
  });
  for (const NodeEntry& e : src->outputs) {
    jgraph->heads.emplace_back(node2index.at(e.node.get()), e.index, e.version);
213
  }
214 215 216 217 218 219 220 221 222 223 224 225 226 227
  // recursively construct subgraphs
  for (JSONNode &jnode : jgraph->nodes) {
    // construct jnode's subgraphs
    const std::vector<std::shared_ptr<Symbol>> &subgraphs = jnode.node->attrs.subgraphs;
    std::vector<JSONGraph> &jsubgraphs = jnode.subgraphs;
    jsubgraphs.resize(subgraphs.size());
    for (uint32_t i = 0; i < subgraphs.size(); ++i) {
      Symbol2JSONGraph(subgraphs[i], &jsubgraphs[i]);
    }
  }
}

std::shared_ptr<Symbol> JSONGraph2Symbol(const JSONGraph &jgraph, bool no_parse) {
  for (const JSONNode &n : jgraph.nodes) {
228 229
    n.node->inputs.reserve(n.inputs.size());
    for (const JSONNode::Entry &e : n.inputs) {
230
      CHECK(e.node_id < jgraph.nodes.size());
231
      n.node->inputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
232 233 234
    }
    n.node->control_deps.reserve(n.control_deps.size());
    for (uint32_t nid : n.control_deps) {
235
      CHECK(nid < jgraph.nodes.size());
236 237
      n.node->control_deps.push_back(jgraph.nodes[nid].node);
    }
238 239 240 241 242 243 244 245
    for (const JSONGraph &subgraph : n.subgraphs) {
      // The "no_parse" option here, is to be compatible with
      // commit cfd3075e85807dcd8f9534c37e053583dee87524
      // (https://github.com/apache/incubator-mxnet/tree/cfd3075e85807dcd8f9534c37e053583dee87524),
      // where the parsing of main graph is deferred until
      // incubator-mxnet/src/nnvm/legacy_json_util.cc:UpgradeJSON_Parse
      n.node->attrs.subgraphs.push_back(JSONGraph2Symbol(subgraph, false));
    }
246 247 248 249 250 251 252
    // rebuild attribute parser
    if (!no_parse && n.node->op() != nullptr && n.node->op()->attr_parser != nullptr) {
      n.node->op()->attr_parser(&(n.node->attrs));
    } else if (!no_parse && n.node->is_variable()) {
      n.node->attrs.parsed =
        Symbol::CreateVariable(n.node->attrs.name).outputs[0].node->attrs.parsed;
    }
253
  }
254
  // consistency check
255
  for (uint32_t nid : jgraph.arg_nodes) {
256
    CHECK(nid < jgraph.nodes.size());
257 258
    CHECK(jgraph.nodes[nid].node->is_variable());
  }
259 260 261
  std::shared_ptr<Symbol> symbol = std::make_shared<Symbol>();
  symbol->outputs.reserve(jgraph.heads.size());
  for (const JSONNode::Entry &e : jgraph.heads) {
262
    CHECK(e.node_id < jgraph.nodes.size());
263 264 265 266
    symbol->outputs.emplace_back(NodeEntry{jgraph.nodes[e.node_id].node, e.index, e.version});
  }
  return symbol;
}
267

268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283
/*!
 * \brief LoadJSON pass: deserialize src.attrs["json"] into a new Graph.
 *
 * If src.attrs["load_json_no_parse"] exists and holds true, attribute
 * parsing is skipped for the top-level nodes (forwarded to
 * JSONGraph2Symbol). Graph-level attributes stored in the JSON become the
 * returned graph's attrs.
 */
Graph LoadJSON(Graph src) {
  CHECK_NE(src.attrs.count("json"), 0U)
      << "Load JSON require json to be presented.";
  const std::string &json_str = nnvm::get<std::string>(*src.attrs.at("json"));
  const bool no_parse = src.attrs.count("load_json_no_parse") != 0 &&
                        nnvm::get<bool>(*src.attrs.at("load_json_no_parse"));

  // Parse the JSON text into the intermediate graph structure.
  JSONGraph jgraph;
  std::istringstream is(json_str);
  dmlc::JSONReader reader(&is);
  jgraph.Load(&reader);

  Graph ret;
  ret.outputs = JSONGraph2Symbol(jgraph, no_parse)->outputs;
  ret.attrs = std::move(jgraph.attrs);
  return ret;
}

// save a graph to json
292
Graph SaveJSON(Graph src) {
293 294
  std::shared_ptr<Symbol> src_symbol = std::make_shared<Symbol>();
  src_symbol->outputs = src.outputs;
295
  JSONGraph jgraph;
296
  Symbol2JSONGraph(src_symbol, &jgraph);
297
  jgraph.attrs = src.attrs;
298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318
  std::ostringstream os;
  dmlc::JSONWriter writer(&os);
  jgraph.Save(&writer);
  Graph ret;
  ret.attrs["json"] = std::make_shared<any>(os.str());
  return ret;
}

// register pass
// LoadJSON reads the "json" graph attribute (declared via depend_graph_attr)
// and returns the reconstructed graph.
NNVM_REGISTER_PASS(LoadJSON)
.describe("Return a new Graph, loaded from src.attrs[\"json\"]")
.set_body(LoadJSON)
.set_change_graph(true)
.depend_graph_attr("json");

// SaveJSON returns a new graph that has only the "json" attribute
// (declared via provide_graph_attr).
NNVM_REGISTER_PASS(SaveJSON)
.describe("Return a new empty Graph. Save graph to ret.attrs[\"json\"]")
.set_body(SaveJSON)
.set_change_graph(true)
.provide_graph_attr("json");

319 320 321

DMLC_JSON_ENABLE_ANY(std::string, str);
DMLC_JSON_ENABLE_ANY(std::vector<int>, list_int);
322
DMLC_JSON_ENABLE_ANY(std::vector<std::string>, list_str);
323

324
}  // namespace
325 326
}  // namespace pass
}  // namespace nnvm