serialization.cc 14.2 KB
Newer Older
1 2 3 4 5 6 7 8
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file node/serialization.cc
 * \brief Utilities to serialize TVM AST/IR objects.
 */
#include <dmlc/json.h>
#include <dmlc/memory_io.h>

#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/node/container.h>
#include <tvm/node/reflection.h>
#include <tvm/node/serialization.h>
#include <tvm/attrs.h>

#include <limits>
#include <map>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include "../common/base64.h"

namespace tvm {

41
inline std::string Type2String(const DataType& t) {
42
  return runtime::TVMType2String(Type2TVMType(t));
43 44 45
}

/*!
 * \brief Parse a textual type back into a Type (counterpart of Type2String).
 *
 * The string is first decoded with runtime::String2TVMType and the result is
 * then lifted to a Type.
 * \param s The textual type.
 * \return The parsed type.
 */
inline Type String2Type(std::string s) {
  const auto runtime_type = runtime::String2TVMType(s);
  return TVMType2Type(runtime_type);
}

49
// indexer to index all the nodes
50 51
class NodeIndexer : public AttrVisitor {
 public:
52 53 54 55 56
  std::unordered_map<Object*, size_t> node_index_{{nullptr, 0}};
  std::vector<Object*> node_list_{nullptr};
  std::unordered_map<DLTensor*, size_t> tensor_index_;
  std::vector<DLTensor*> tensor_list_;
  ReflectionVTable* reflection_ = ReflectionVTable::Global();
57 58 59 60 61 62 63

  void Visit(const char* key, double* value) final {}
  void Visit(const char* key, int64_t* value) final {}
  void Visit(const char* key, uint64_t* value) final {}
  void Visit(const char* key, int* value) final {}
  void Visit(const char* key, bool* value) final {}
  void Visit(const char* key, std::string* value) final {}
64
  void Visit(const char* key, void** value) final {}
65
  void Visit(const char* key, DataType* value) final {}
66

67 68
  void Visit(const char* key, runtime::NDArray* value) final {
    DLTensor* ptr = const_cast<DLTensor*>((*value).operator->());
69 70 71 72
    if (tensor_index_.count(ptr)) return;
    CHECK_EQ(tensor_index_.size(), tensor_list_.size());
    tensor_index_[ptr] = tensor_list_.size();
    tensor_list_.push_back(ptr);
73
  }
74

75
  void Visit(const char* key, ObjectRef* value) final {
76
    MakeIndex(const_cast<Object*>(value->get()));
77 78
  }

79
  // make index of all the children of node
80 81 82
  void MakeIndex(Object* node) {
    if (node == nullptr) return;
    CHECK(node->IsInstance<Node>());
83

84 85 86 87
    if (node_index_.count(node)) return;
    CHECK_EQ(node_index_.size(), node_list_.size());
    node_index_[node] = node_list_.size();
    node_list_.push_back(node);
88

89
    if (node->IsInstance<ArrayNode>()) {
90 91
      ArrayNode* n = static_cast<ArrayNode*>(node);
      for (const auto& sp : n->data) {
92
        MakeIndex(const_cast<Object*>(sp.get()));
93
      }
94
    } else if (node->IsInstance<MapNode>()) {
95 96
      MapNode* n = static_cast<MapNode*>(node);
      for (const auto& kv : n->data) {
97 98
        MakeIndex(const_cast<Object*>(kv.first.get()));
        MakeIndex(const_cast<Object*>(kv.second.get()));
99
      }
100
    } else if (node->IsInstance<StrMapNode>()) {
101 102
      StrMapNode* n = static_cast<StrMapNode*>(node);
      for (const auto& kv : n->data) {
103
        MakeIndex(const_cast<Object*>(kv.second.get()));
104
      }
105
    } else {
106
      reflection_->VisitAttrs(node, this);
107 108 109 110 111 112 113
    }
  }
};

// use map so attributes are ordered.
using AttrMap = std::map<std::string, std::string>;

114
/*! \brief Node structure for json format. */
115
struct JSONNode {
116
  /*! \brief The type of key of the object. */
117
  std::string type_key;
118
  /*! \brief The global key for global object. */
119
  std::string global_key;
120
  /*! \brief the attributes */
121
  AttrMap attrs;
122
  /*! \brief keys of a map. */
123
  std::vector<std::string> keys;
124
  /*! \brief values of a map or array. */
125 126 127 128 129
  std::vector<size_t> data;

  void Save(dmlc::JSONWriter *writer) const {
    writer->BeginObject();
    writer->WriteObjectKeyValue("type_key", type_key);
130 131 132
    if (global_key.size() != 0) {
      writer->WriteObjectKeyValue("global_key", global_key);
    }
133 134 135
    if (attrs.size() != 0) {
      writer->WriteObjectKeyValue("attrs", attrs);
    }
136 137 138
    if (keys.size() != 0) {
      writer->WriteObjectKeyValue("keys", keys);
    }
139 140 141 142 143 144 145 146 147
    if (data.size() != 0) {
      writer->WriteObjectKeyValue("data", data);
    }
    writer->EndObject();
  }

  void Load(dmlc::JSONReader *reader) {
    attrs.clear();
    data.clear();
148
    global_key.clear();
149 150 151
    type_key.clear();
    dmlc::JSONObjectReadHelper helper;
    helper.DeclareOptionalField("type_key", &type_key);
152
    helper.DeclareOptionalField("global_key", &global_key);
153
    helper.DeclareOptionalField("attrs", &attrs);
154
    helper.DeclareOptionalField("keys", &keys);
155 156 157 158 159
    helper.DeclareOptionalField("data", &data);
    helper.ReadAllFields(reader);
  }
};

160 161
// Helper class to populate the json node
// using the existing index.
162 163
class JSONAttrGetter : public AttrVisitor {
 public:
164
  const std::unordered_map<Object*, size_t>* node_index_;
165
  const std::unordered_map<DLTensor*, size_t>* tensor_index_;
166
  JSONNode* node_;
167
  ReflectionVTable* reflection_ = ReflectionVTable::Global();
168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186

  void Visit(const char* key, double* value) final {
    node_->attrs[key] = std::to_string(*value);
  }
  void Visit(const char* key, int64_t* value) final {
    node_->attrs[key] = std::to_string(*value);
  }
  void Visit(const char* key, uint64_t* value) final {
    node_->attrs[key] = std::to_string(*value);
  }
  void Visit(const char* key, int* value) final {
    node_->attrs[key] = std::to_string(*value);
  }
  void Visit(const char* key, bool* value) final {
    node_->attrs[key] = std::to_string(*value);
  }
  void Visit(const char* key, std::string* value) final {
    node_->attrs[key] = *value;
  }
187 188 189
  void Visit(const char* key, void** value) final {
    LOG(FATAL) << "not allowed to serialize a pointer";
  }
190
  void Visit(const char* key, DataType* value) final {
191 192
    node_->attrs[key] = Type2String(*value);
  }
193

194 195 196 197
  void Visit(const char* key, runtime::NDArray* value) final {
    node_->attrs[key] = std::to_string(
        tensor_index_->at(const_cast<DLTensor*>((*value).operator->())));
  }
198

199
  void Visit(const char* key, ObjectRef* value) final {
200 201
    node_->attrs[key] = std::to_string(
        node_index_->at(const_cast<Object*>(value->get())));
202
  }
203

204
  // Get the node
205 206
  void Get(Object* node) {
    if (node == nullptr) {
207 208 209
      node_->type_key.clear();
      return;
    }
210
    node_->type_key = node->GetTypeKey();
211 212 213 214
    node_->global_key = reflection_->GetGlobalKey(node);
    // No need to recursively visit fields of global singleton
    // They are registered via the environment.
    if (node_->global_key.length() != 0) return;
215

216
    // populates the fields.
217 218
    node_->attrs.clear();
    node_->data.clear();
219

220
    if (node->IsInstance<ArrayNode>()) {
221 222 223
      ArrayNode* n = static_cast<ArrayNode*>(node);
      for (size_t i = 0; i < n->data.size(); ++i) {
        node_->data.push_back(
224
            node_index_->at(const_cast<Object*>(n->data[i].get())));
225
      }
226
    } else if (node->IsInstance<MapNode>()) {
227 228 229
      MapNode* n = static_cast<MapNode*>(node);
      for (const auto& kv : n->data) {
        node_->data.push_back(
230
            node_index_->at(const_cast<Object*>(kv.first.get())));
231
        node_->data.push_back(
232
            node_index_->at(const_cast<Object*>(kv.second.get())));
233
      }
234
    } else if (node->IsInstance<StrMapNode>()) {
235 236 237 238
      StrMapNode* n = static_cast<StrMapNode*>(node);
      for (const auto& kv : n->data) {
        node_->keys.push_back(kv.first);
        node_->data.push_back(
239
            node_index_->at(const_cast<Object*>(kv.second.get())));
240
      }
241
    } else {
242
      // recursively index normal object.
243
      reflection_->VisitAttrs(node, this);
244 245 246 247
    }
  }
};

248 249
// Helper class to set the attributes of a node
// from given json node.
250 251
class JSONAttrSetter : public AttrVisitor {
 public:
252
  const std::vector<ObjectPtr<Object> >* node_list_;
253
  const std::vector<runtime::NDArray>* tensor_list_;
254 255
  JSONNode* node_;

256 257
  ReflectionVTable* reflection_ = ReflectionVTable::Global();

258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290
  std::string GetValue(const char* key) const {
    auto it = node_->attrs.find(key);
    if (it == node_->attrs.end()) {
      LOG(FATAL) << "JSONReader: cannot find field " << key;
    }
    return it->second;
  }
  template<typename T>
  void ParseValue(const char* key, T* value) const {
    std::istringstream is(GetValue(key));
    is >> *value;
    if (is.fail()) {
      LOG(FATAL) << "Wrong value format for field " << key;
    }
  }
  void Visit(const char* key, double* value) final {
    ParseValue(key, value);
  }
  void Visit(const char* key, int64_t* value) final {
    ParseValue(key, value);
  }
  void Visit(const char* key, uint64_t* value) final {
    ParseValue(key, value);
  }
  void Visit(const char* key, int* value) final {
    ParseValue(key, value);
  }
  void Visit(const char* key, bool* value) final {
    ParseValue(key, value);
  }
  void Visit(const char* key, std::string* value) final {
    *value = GetValue(key);
  }
291 292 293
  void Visit(const char* key, void** value) final {
    LOG(FATAL) << "not allowed to deserialize a pointer";
  }
294
  void Visit(const char* key, DataType* value) final {
295 296 297
    std::string stype = GetValue(key);
    *value = String2Type(stype);
  }
298 299 300 301 302 303
  void Visit(const char* key, runtime::NDArray* value) final {
    size_t index;
    ParseValue(key, &index);
    CHECK_LE(index, tensor_list_->size());
    *value = tensor_list_->at(index);
  }
304
  void Visit(const char* key, ObjectRef* value) final {
305 306 307 308
    size_t index;
    ParseValue(key, &index);
    CHECK_LE(index, node_list_->size());
    *value = ObjectRef(node_list_->at(index));
309
  }
310
  // set node to be current JSONNode
311 312
  void Set(Object* node) {
    if (node == nullptr) return;
313 314

    if (node->IsInstance<ArrayNode>()) {
315 316 317
      ArrayNode* n = static_cast<ArrayNode*>(node);
      n->data.clear();
      for (size_t index : node_->data) {
318
        n->data.push_back(ObjectRef(node_list_->at(index)));
319
      }
320
    } else if (node->IsInstance<MapNode>()) {
321 322 323
      MapNode* n = static_cast<MapNode*>(node);
      CHECK_EQ(node_->data.size() % 2, 0U);
      for (size_t i = 0; i < node_->data.size(); i += 2) {
324 325
        n->data[ObjectRef(node_list_->at(node_->data[i]))]
            = ObjectRef(node_list_->at(node_->data[i + 1]));
326
      }
327
    } else if (node->IsInstance<StrMapNode>()) {
328 329 330 331
      StrMapNode* n = static_cast<StrMapNode*>(node);
      CHECK_EQ(node_->data.size(), node_->keys.size());
      for (size_t i = 0; i < node_->data.size(); ++i) {
        n->data[node_->keys[i]]
332
            = ObjectRef(node_list_->at(node_->data[i]));
333
      }
334
    } else {
335
      reflection_->VisitAttrs(node, this);
336 337 338 339 340 341 342 343 344 345
    }
  }
};

// json graph structure to store node
struct JSONGraph {
  // the root of the graph
  size_t root;
  // the nodes of the graph
  std::vector<JSONNode> nodes;
346 347
  // base64 b64ndarrays of arrays
  std::vector<std::string> b64ndarrays;
348 349 350 351 352 353 354
  // global attributes
  AttrMap attrs;

  void Save(dmlc::JSONWriter *writer) const {
    writer->BeginObject();
    writer->WriteObjectKeyValue("root", root);
    writer->WriteObjectKeyValue("nodes", nodes);
355
    writer->WriteObjectKeyValue("b64ndarrays", b64ndarrays);
356 357 358 359 360 361 362 363 364 365 366
    if (attrs.size() != 0) {
      writer->WriteObjectKeyValue("attrs", attrs);
    }
    writer->EndObject();
  }

  void Load(dmlc::JSONReader *reader) {
    attrs.clear();
    dmlc::JSONObjectReadHelper helper;
    helper.DeclareField("root", &root);
    helper.DeclareField("nodes", &nodes);
367
    helper.DeclareOptionalField("b64ndarrays", &b64ndarrays);
368 369 370 371
    helper.DeclareOptionalField("attrs", &attrs);
    helper.ReadAllFields(reader);
  }

372
  static JSONGraph Create(const ObjectRef& root) {
373 374
    JSONGraph g;
    NodeIndexer indexer;
375
    indexer.MakeIndex(const_cast<Object*>(root.get()));
376
    JSONAttrGetter getter;
377 378 379
    getter.node_index_ = &indexer.node_index_;
    getter.tensor_index_ = &indexer.tensor_index_;
    for (Object* n : indexer.node_list_) {
380 381 382 383 384
      JSONNode jnode;
      getter.node_ = &jnode;
      getter.Get(n);
      g.nodes.emplace_back(std::move(jnode));
    }
385
    g.attrs["tvm_version"] = TVM_VERSION;
386
    g.root = indexer.node_index_.at(const_cast<Object*>(root.get()));
387
    // serialize tensor
388
    for (DLTensor* tensor : indexer.tensor_list_) {
389 390 391 392 393 394 395
      std::string blob;
      dmlc::MemoryStringStream mstrm(&blob);
      common::Base64OutStream b64strm(&mstrm);
      runtime::SaveDLTensor(&b64strm, tensor);
      b64strm.Finish();
      g.b64ndarrays.emplace_back(std::move(blob));
    }
396 397 398 399
    return g;
  }
};

400
std::string SaveJSON(const ObjectRef& n) {
401 402 403 404 405 406 407
  auto jgraph = JSONGraph::Create(n);
  std::ostringstream os;
  dmlc::JSONWriter writer(&os);
  jgraph.Save(&writer);
  return os.str();
}

408
ObjectRef LoadJSON(std::string json_str) {
409 410 411 412 413
  std::istringstream is(json_str);
  dmlc::JSONReader reader(&is);
  JSONGraph jgraph;
  // load in json graph.
  jgraph.Load(&reader);
414
  std::vector<ObjectPtr<Object> > nodes;
415 416 417 418 419 420 421 422 423 424
  std::vector<runtime::NDArray> tensors;
  // load in tensors
  for (const std::string& blob : jgraph.b64ndarrays) {
    dmlc::MemoryStringStream mstrm(const_cast<std::string*>(&blob));
    common::Base64InStream b64strm(&mstrm);
    b64strm.InitPosition();
    runtime::NDArray temp;
    CHECK(temp.Load(&b64strm));
    tensors.emplace_back(temp);
  }
425 426
  ReflectionVTable* reflection = ReflectionVTable::Global();

427 428
  // node 0 is always null
  nodes.reserve(jgraph.nodes.size());
429

430 431
  for (const JSONNode& jnode : jgraph.nodes) {
    if (jnode.type_key.length() != 0) {
432 433 434
      ObjectPtr<Object> node =
          reflection->CreateInitObject(jnode.type_key, jnode.global_key);
      nodes.emplace_back(node);
435
    } else {
436
      nodes.emplace_back(ObjectPtr<Object>());
437 438 439 440 441
    }
  }
  CHECK_EQ(nodes.size(), jgraph.nodes.size());
  JSONAttrSetter setter;
  setter.node_list_ = &nodes;
442
  setter.tensor_list_ = &tensors;
443 444 445

  for (size_t i = 0; i < nodes.size(); ++i) {
    setter.node_ = &jgraph.nodes[i];
446 447 448 449 450
    // do not need to recover content of global singleton object
    // they are registered via the environment
    if (setter.node_->global_key.length() == 0) {
      setter.Set(nodes[i].get());
    }
451
  }
452
  return ObjectRef(nodes.at(jgraph.root));
453 454
}
}  // namespace tvm