/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file node/serialization.cc
 * \brief Utilities to serialize TVM AST/IR objects.
 */
#include <dmlc/json.h>
25
#include <dmlc/memory_io.h>
26
#include <tvm/runtime/registry.h>
27 28 29 30 31
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/node/container.h>
#include <tvm/node/reflection.h>
#include <tvm/node/serialization.h>
32
#include <tvm/ir/attrs.h>
33

34
#include <string>
35
#include <cctype>
36
#include <map>
37

38
#include "../support/base64.h"
39

40 41
namespace tvm {

42
inline std::string Type2String(const DataType& t) {
43
  return runtime::DLDataType2String(t);
44 45
}

46
/*! \brief Parse a type string with runtime::String2DLDataType and wrap the result in a DataType. */
inline DataType String2Type(std::string s) {
  const auto parsed = runtime::String2DLDataType(s);
  return DataType(parsed);
}

50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
inline std::string Base64Decode(std::string s) {
  dmlc::MemoryStringStream mstrm(&s);
  support::Base64InStream b64strm(&mstrm);
  std::string output;
  b64strm.InitPosition();
  dmlc::Stream* strm = &b64strm;
  strm->Read(&output);
  return output;
}

inline std::string Base64Encode(std::string s) {
  std::string blob;
  dmlc::MemoryStringStream mstrm(&blob);
  support::Base64OutStream b64strm(&mstrm);
  dmlc::Stream* strm = &b64strm;
  strm->Write(s);
  b64strm.Finish();
  return blob;
}

70
// indexer to index all the nodes
71 72
class NodeIndexer : public AttrVisitor {
 public:
73 74 75 76 77
  std::unordered_map<Object*, size_t> node_index_{{nullptr, 0}};
  std::vector<Object*> node_list_{nullptr};
  std::unordered_map<DLTensor*, size_t> tensor_index_;
  std::vector<DLTensor*> tensor_list_;
  ReflectionVTable* reflection_ = ReflectionVTable::Global();
78 79 80 81 82 83 84

  void Visit(const char* key, double* value) final {}
  void Visit(const char* key, int64_t* value) final {}
  void Visit(const char* key, uint64_t* value) final {}
  void Visit(const char* key, int* value) final {}
  void Visit(const char* key, bool* value) final {}
  void Visit(const char* key, std::string* value) final {}
85
  void Visit(const char* key, void** value) final {}
86
  void Visit(const char* key, DataType* value) final {}
87

88 89
  void Visit(const char* key, runtime::NDArray* value) final {
    DLTensor* ptr = const_cast<DLTensor*>((*value).operator->());
90 91 92 93
    if (tensor_index_.count(ptr)) return;
    CHECK_EQ(tensor_index_.size(), tensor_list_.size());
    tensor_index_[ptr] = tensor_list_.size();
    tensor_list_.push_back(ptr);
94
  }
95

96
  void Visit(const char* key, ObjectRef* value) final {
97
    MakeIndex(const_cast<Object*>(value->get()));
98 99
  }

100
  // make index of all the children of node
101 102
  void MakeIndex(Object* node) {
    if (node == nullptr) return;
103
    CHECK(node->IsInstance<Object>());
104

105 106 107 108
    if (node_index_.count(node)) return;
    CHECK_EQ(node_index_.size(), node_list_.size());
    node_index_[node] = node_list_.size();
    node_list_.push_back(node);
109

110
    if (node->IsInstance<ArrayNode>()) {
111 112
      ArrayNode* n = static_cast<ArrayNode*>(node);
      for (const auto& sp : n->data) {
113
        MakeIndex(const_cast<Object*>(sp.get()));
114
      }
115
    } else if (node->IsInstance<MapNode>()) {
116 117
      MapNode* n = static_cast<MapNode*>(node);
      for (const auto& kv : n->data) {
118 119
        MakeIndex(const_cast<Object*>(kv.first.get()));
        MakeIndex(const_cast<Object*>(kv.second.get()));
120
      }
121
    } else if (node->IsInstance<StrMapNode>()) {
122 123
      StrMapNode* n = static_cast<StrMapNode*>(node);
      for (const auto& kv : n->data) {
124
        MakeIndex(const_cast<Object*>(kv.second.get()));
125
      }
126
    } else {
127 128 129 130
      // if the node already have repr bytes, no need to visit Attrs.
      if (!reflection_->GetReprBytes(node, nullptr)) {
        reflection_->VisitAttrs(node, this);
      }
131 132 133 134 135 136 137
    }
  }
};

// Attribute table: attribute name -> attribute value stored as a string.
// use map so attributes are ordered.
using AttrMap = std::map<std::string, std::string>;

138
/*! \brief Node structure for json format. */
139
struct JSONNode {
140
  /*! \brief The type of key of the object. */
141
  std::string type_key;
142 143
  /*! \brief The str repr representation. */
  std::string repr_bytes;
144
  /*! \brief the attributes */
145
  AttrMap attrs;
146
  /*! \brief keys of a map. */
147
  std::vector<std::string> keys;
148
  /*! \brief values of a map or array. */
149 150 151 152 153
  std::vector<size_t> data;

  void Save(dmlc::JSONWriter *writer) const {
    writer->BeginObject();
    writer->WriteObjectKeyValue("type_key", type_key);
154 155 156 157 158 159 160 161 162
    if (repr_bytes.size() != 0) {
      // choose to use str representation or base64, based on whether
      // the byte representation is printable.
      if (std::all_of(repr_bytes.begin(), repr_bytes.end(),
                      [](char ch) { return std::isprint(ch); })) {
        writer->WriteObjectKeyValue("repr_str", repr_bytes);
      } else {
        writer->WriteObjectKeyValue("repr_b64", Base64Encode(repr_bytes));
      }
163
    }
164 165 166
    if (attrs.size() != 0) {
      writer->WriteObjectKeyValue("attrs", attrs);
    }
167 168 169
    if (keys.size() != 0) {
      writer->WriteObjectKeyValue("keys", keys);
    }
170 171 172 173 174 175 176 177 178
    if (data.size() != 0) {
      writer->WriteObjectKeyValue("data", data);
    }
    writer->EndObject();
  }

  void Load(dmlc::JSONReader *reader) {
    attrs.clear();
    data.clear();
179
    repr_bytes.clear();
180
    type_key.clear();
181
    std::string repr_b64, repr_str;
182 183
    dmlc::JSONObjectReadHelper helper;
    helper.DeclareOptionalField("type_key", &type_key);
184 185
    helper.DeclareOptionalField("repr_b64", &repr_b64);
    helper.DeclareOptionalField("repr_str", &repr_str);
186
    helper.DeclareOptionalField("attrs", &attrs);
187
    helper.DeclareOptionalField("keys", &keys);
188 189
    helper.DeclareOptionalField("data", &data);
    helper.ReadAllFields(reader);
190 191 192 193 194 195 196

    if (repr_str.size() != 0) {
      CHECK_EQ(repr_b64.size(), 0U);
      repr_bytes = std::move(repr_str);
    } else if (repr_b64.size() != 0) {
      repr_bytes = Base64Decode(repr_b64);
    }
197 198 199
  }
};

200 201
// Helper class to populate the json node
// using the existing index.
202 203
class JSONAttrGetter : public AttrVisitor {
 public:
204
  const std::unordered_map<Object*, size_t>* node_index_;
205
  const std::unordered_map<DLTensor*, size_t>* tensor_index_;
206
  JSONNode* node_;
207
  ReflectionVTable* reflection_ = ReflectionVTable::Global();
208 209

  void Visit(const char* key, double* value) final {
210 211 212 213 214
    std::ostringstream s;
    // Type <double> have approximately 16 decimal digits
    s.precision(16);
    s << (*value);
    node_->attrs[key] = s.str();
215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230
  }
  void Visit(const char* key, int64_t* value) final {
    node_->attrs[key] = std::to_string(*value);
  }
  void Visit(const char* key, uint64_t* value) final {
    node_->attrs[key] = std::to_string(*value);
  }
  void Visit(const char* key, int* value) final {
    node_->attrs[key] = std::to_string(*value);
  }
  void Visit(const char* key, bool* value) final {
    node_->attrs[key] = std::to_string(*value);
  }
  void Visit(const char* key, std::string* value) final {
    node_->attrs[key] = *value;
  }
231 232 233
  void Visit(const char* key, void** value) final {
    LOG(FATAL) << "not allowed to serialize a pointer";
  }
234
  void Visit(const char* key, DataType* value) final {
235 236
    node_->attrs[key] = Type2String(*value);
  }
237

238 239 240 241
  void Visit(const char* key, runtime::NDArray* value) final {
    node_->attrs[key] = std::to_string(
        tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
  }
242

243
  void Visit(const char* key, ObjectRef* value) final {
244 245
    node_->attrs[key] = std::to_string(
        node_index_->at(const_cast<Object*>(value->get())));
246
  }
247

248
  // Get the node
249 250
  void Get(Object* node) {
    if (node == nullptr) {
251 252 253
      node_->type_key.clear();
      return;
    }
254
    node_->type_key = node->GetTypeKey();
255 256
    // do not need to print additional things once we have repr bytes.
    if (reflection_->GetReprBytes(node, &(node_->repr_bytes))) return;
257

258
    // populates the fields.
259 260
    node_->attrs.clear();
    node_->data.clear();
261

262
    if (node->IsInstance<ArrayNode>()) {
263 264 265
      ArrayNode* n = static_cast<ArrayNode*>(node);
      for (size_t i = 0; i < n->data.size(); ++i) {
        node_->data.push_back(
266
            node_index_->at(const_cast<Object*>(n->data[i].get())));
267
      }
268
    } else if (node->IsInstance<MapNode>()) {
269 270 271
      MapNode* n = static_cast<MapNode*>(node);
      for (const auto& kv : n->data) {
        node_->data.push_back(
272
            node_index_->at(const_cast<Object*>(kv.first.get())));
273
        node_->data.push_back(
274
            node_index_->at(const_cast<Object*>(kv.second.get())));
275
      }
276
    } else if (node->IsInstance<StrMapNode>()) {
277 278 279 280
      StrMapNode* n = static_cast<StrMapNode*>(node);
      for (const auto& kv : n->data) {
        node_->keys.push_back(kv.first);
        node_->data.push_back(
281
            node_index_->at(const_cast<Object*>(kv.second.get())));
282
      }
283
    } else {
284
      // recursively index normal object.
285
      reflection_->VisitAttrs(node, this);
286 287 288 289
    }
  }
};

290 291
// Helper class to set the attributes of a node
// from given json node.
292 293
class JSONAttrSetter : public AttrVisitor {
 public:
294
  const std::vector<ObjectPtr<Object> >* node_list_;
295
  const std::vector<runtime::NDArray>* tensor_list_;
296 297
  JSONNode* node_;

298 299
  ReflectionVTable* reflection_ = ReflectionVTable::Global();

300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332
  std::string GetValue(const char* key) const {
    auto it = node_->attrs.find(key);
    if (it == node_->attrs.end()) {
      LOG(FATAL) << "JSONReader: cannot find field " << key;
    }
    return it->second;
  }
  template<typename T>
  void ParseValue(const char* key, T* value) const {
    std::istringstream is(GetValue(key));
    is >> *value;
    if (is.fail()) {
      LOG(FATAL) << "Wrong value format for field " << key;
    }
  }
  void Visit(const char* key, double* value) final {
    ParseValue(key, value);
  }
  void Visit(const char* key, int64_t* value) final {
    ParseValue(key, value);
  }
  void Visit(const char* key, uint64_t* value) final {
    ParseValue(key, value);
  }
  void Visit(const char* key, int* value) final {
    ParseValue(key, value);
  }
  void Visit(const char* key, bool* value) final {
    ParseValue(key, value);
  }
  void Visit(const char* key, std::string* value) final {
    *value = GetValue(key);
  }
333 334 335
  void Visit(const char* key, void** value) final {
    LOG(FATAL) << "not allowed to deserialize a pointer";
  }
336
  void Visit(const char* key, DataType* value) final {
337 338 339
    std::string stype = GetValue(key);
    *value = String2Type(stype);
  }
340 341 342 343 344 345
  void Visit(const char* key, runtime::NDArray* value) final {
    size_t index;
    ParseValue(key, &index);
    CHECK_LE(index, tensor_list_->size());
    *value = tensor_list_->at(index);
  }
346
  void Visit(const char* key, ObjectRef* value) final {
347 348 349 350
    size_t index;
    ParseValue(key, &index);
    CHECK_LE(index, node_list_->size());
    *value = ObjectRef(node_list_->at(index));
351
  }
352
  // set node to be current JSONNode
353 354
  void Set(Object* node) {
    if (node == nullptr) return;
355 356

    if (node->IsInstance<ArrayNode>()) {
357 358 359
      ArrayNode* n = static_cast<ArrayNode*>(node);
      n->data.clear();
      for (size_t index : node_->data) {
360
        n->data.push_back(ObjectRef(node_list_->at(index)));
361
      }
362
    } else if (node->IsInstance<MapNode>()) {
363 364 365
      MapNode* n = static_cast<MapNode*>(node);
      CHECK_EQ(node_->data.size() % 2, 0U);
      for (size_t i = 0; i < node_->data.size(); i += 2) {
366 367
        n->data[ObjectRef(node_list_->at(node_->data[i]))]
            = ObjectRef(node_list_->at(node_->data[i + 1]));
368
      }
369
    } else if (node->IsInstance<StrMapNode>()) {
370 371 372 373
      StrMapNode* n = static_cast<StrMapNode*>(node);
      CHECK_EQ(node_->data.size(), node_->keys.size());
      for (size_t i = 0; i < node_->data.size(); ++i) {
        n->data[node_->keys[i]]
374
            = ObjectRef(node_list_->at(node_->data[i]));
375
      }
376
    } else {
377
      reflection_->VisitAttrs(node, this);
378 379 380 381 382 383 384 385 386 387
    }
  }
};

// json graph structure to store node
struct JSONGraph {
  // the root of the graph
  size_t root;
  // the nodes of the graph
  std::vector<JSONNode> nodes;
388 389
  // base64 b64ndarrays of arrays
  std::vector<std::string> b64ndarrays;
390 391 392 393 394 395 396
  // global attributes
  AttrMap attrs;

  void Save(dmlc::JSONWriter *writer) const {
    writer->BeginObject();
    writer->WriteObjectKeyValue("root", root);
    writer->WriteObjectKeyValue("nodes", nodes);
397
    writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays);
398 399 400 401 402 403 404 405 406 407 408
    if (attrs.size() != 0) {
      writer->WriteObjectKeyValue("attrs", attrs);
    }
    writer->EndObject();
  }

  void Load(dmlc::JSONReader *reader) {
    attrs.clear();
    dmlc::JSONObjectReadHelper helper;
    helper.DeclareField("root", &root);
    helper.DeclareField("nodes", &nodes);
409
    helper.DeclareOptionalField("b64ndarrays", &b64ndarrays);
410 411 412 413
    helper.DeclareOptionalField("attrs", &attrs);
    helper.ReadAllFields(reader);
  }

414
  static JSONGraph Create(const ObjectRef& root) {
415 416
    JSONGraph g;
    NodeIndexer indexer;
417
    indexer.MakeIndex(const_cast<Object*>(root.get()));
418
    JSONAttrGetter getter;
419 420 421
    getter.node_index_ = &indexer.node_index_;
    getter.tensor_index_ = &indexer.tensor_index_;
    for (Object* n : indexer.node_list_) {
422 423 424 425 426
      JSONNode jnode;
      getter.node_ = &jnode;
      getter.Get(n);
      g.nodes.emplace_back(std::move(jnode));
    }
427
    g.attrs["tvm_version"] = TVM_VERSION;
428
    g.root = indexer.node_index_.at(const_cast<Object*>(root.get()));
429
    // serialize tensor
430
    for (DLTensor* tensor : indexer.tensor_list_) {
431 432
      std::string blob;
      dmlc::MemoryStringStream mstrm(&blob);
433
      support::Base64OutStream b64strm(&mstrm);
434 435 436 437
      runtime::SaveDLTensor(&b64strm, tensor);
      b64strm.Finish();
      g.b64ndarrays.emplace_back(std::move(blob));
    }
438 439 440 441
    return g;
  }
};

442
std::string SaveJSON(const ObjectRef& n) {
443 444 445 446 447 448 449
  auto jgraph = JSONGraph::Create(n);
  std::ostringstream os;
  dmlc::JSONWriter writer(&os);
  jgraph.Save(&writer);
  return os.str();
}

450
/*!
 * \brief Rebuild an object graph from a JSON string.
 * \param json_str The JSON text to parse into a JSONGraph.
 * \return A reference to the node at the graph's root index.
 */
ObjectRef LoadJSON(std::string json_str) {
  // parse the JSON text into the intermediate graph.
  std::istringstream input(json_str);
  dmlc::JSONReader json_reader(&input);
  JSONGraph jgraph;
  jgraph.Load(&json_reader);

  // decode every base64 blob back into an NDArray.
  std::vector<runtime::NDArray> tensor_table;
  for (const std::string& encoded : jgraph.b64ndarrays) {
    dmlc::MemoryStringStream raw(const_cast<std::string*>(&encoded));
    support::Base64InStream b64strm(&raw);
    b64strm.InitPosition();
    runtime::NDArray temp;
    CHECK(temp.Load(&b64strm));
    tensor_table.emplace_back(temp);
  }

  ReflectionVTable* vtable = ReflectionVTable::Global();

  // first pass: create one object per JSON node.
  // a node with an empty type_key becomes a null entry.
  std::vector<ObjectPtr<Object> > nodes;
  nodes.reserve(jgraph.nodes.size());
  for (const JSONNode& jnode : jgraph.nodes) {
    if (jnode.type_key.length() == 0) {
      nodes.emplace_back(ObjectPtr<Object>());
      continue;
    }
    nodes.emplace_back(
        vtable->CreateInitObject(jnode.type_key, jnode.repr_bytes));
  }
  CHECK_EQ(nodes.size(), jgraph.nodes.size());

  // second pass: fill in attributes now that every node exists.
  JSONAttrSetter setter;
  setter.node_list_ = &nodes;
  setter.tensor_list_ = &tensor_table;
  for (size_t i = 0; i < nodes.size(); ++i) {
    JSONNode& jnode = jgraph.nodes[i];
    // Skip the nodes that has an repr bytes representation.
    if (jnode.repr_bytes.length() != 0) continue;
    if (nodes[i].get() == nullptr) continue;
    // NOTE: this guards the case where the repr bytes itself
    // is an empty string "".
    if (vtable->GetReprBytes(nodes[i].get(), nullptr)) continue;
    setter.node_ = &jnode;
    setter.Set(nodes[i].get());
  }
  return ObjectRef(nodes.at(jgraph.root));
}
499 500 501 502 503 504

// Register SaveJSON / LoadJSON as global functions under the "node." prefix.
TVM_REGISTER_GLOBAL("node.SaveJSON").set_body_typed(SaveJSON);

TVM_REGISTER_GLOBAL("node.LoadJSON").set_body_typed(LoadJSON);
505
}  // namespace tvm