Unverified Commit 029388f5 by Tianqi Chen Committed by GitHub

[NODE] General serialzation of leaf objects into bytes. (#5299)

This PR refactors the serialization mechanism to support general
serialization of leaf objects into bytes.

The new feature superceded the original GetGlobalKey feature for singletons.
Added serialization support for runtime::String.
parent 7d670b04
...@@ -98,17 +98,17 @@ class ReflectionVTable { ...@@ -98,17 +98,17 @@ class ReflectionVTable {
typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce); typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce);
/*! /*!
* \brief creator function. * \brief creator function.
* \param global_key Key that identifies a global single object. * \param repr_bytes Repr bytes to create the object.
* If this is not empty then FGlobalKey must be defined for the object. * If this is not empty then FReprBytes must be defined for the object.
* \return The created function. * \return The created function.
*/ */
typedef ObjectPtr<Object> (*FCreate)(const std::string& global_key); typedef ObjectPtr<Object> (*FCreate)(const std::string& repr_bytes);
/*! /*!
* \brief Global key function, only needed by global objects. * \brief Function to get a byte representation that can be used to recover the object.
* \param node The node pointer. * \param node The node pointer.
* \return node The global key to the node. * \return bytes The bytes that can be used to recover the object.
*/ */
typedef std::string (*FGlobalKey)(const Object* self); typedef std::string (*FReprBytes)(const Object* self);
/*! /*!
* \brief Dispatch the VisitAttrs function. * \brief Dispatch the VisitAttrs function.
* \param self The pointer to the object. * \param self The pointer to the object.
...@@ -116,11 +116,13 @@ class ReflectionVTable { ...@@ -116,11 +116,13 @@ class ReflectionVTable {
*/ */
inline void VisitAttrs(Object* self, AttrVisitor* visitor) const; inline void VisitAttrs(Object* self, AttrVisitor* visitor) const;
/*! /*!
* \brief Get global key of the object, if any. * \brief Get repr bytes if any.
* \param self The pointer to the object. * \param self The pointer to the object.
* \return the global key if object has one, otherwise return empty string. * \param repr_bytes The output repr bytes, can be null, in which case the function
* simply queries if the ReprBytes function exists for the type.
* \return Whether repr bytes exists
*/ */
inline std::string GetGlobalKey(Object* self) const; inline bool GetReprBytes(const Object* self, std::string* repr_bytes) const;
/*! /*!
* \brief Dispatch the SEqualReduce function. * \brief Dispatch the SEqualReduce function.
* \param self The pointer to the object. * \param self The pointer to the object.
...@@ -141,10 +143,10 @@ class ReflectionVTable { ...@@ -141,10 +143,10 @@ class ReflectionVTable {
* by type_key and global key. * by type_key and global key.
* *
* \param type_key The type key of the object. * \param type_key The type key of the object.
* \param global_key A global key that can be used to uniquely identify the object if any. * \param repr_bytes Bytes representation of the object if any.
*/ */
TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key, TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key,
const std::string& global_key = "") const; const std::string& repr_bytes = "") const;
/*! /*!
* \brief Get an field object by the attr name. * \brief Get an field object by the attr name.
* \param self The pointer to the object. * \param self The pointer to the object.
...@@ -176,8 +178,8 @@ class ReflectionVTable { ...@@ -176,8 +178,8 @@ class ReflectionVTable {
std::vector<FSHashReduce> fshash_reduce_; std::vector<FSHashReduce> fshash_reduce_;
/*! \brief Creation function. */ /*! \brief Creation function. */
std::vector<FCreate> fcreate_; std::vector<FCreate> fcreate_;
/*! \brief Global key function. */ /*! \brief ReprBytes function. */
std::vector<FGlobalKey> fglobal_key_; std::vector<FReprBytes> frepr_bytes_;
}; };
/*! \brief Registry of a reflection table. */ /*! \brief Registry of a reflection table. */
...@@ -196,13 +198,13 @@ class ReflectionVTable::Registry { ...@@ -196,13 +198,13 @@ class ReflectionVTable::Registry {
return *this; return *this;
} }
/*! /*!
* \brief Set global_key function. * \brief Set bytes repr function.
* \param f The creator function. * \param f The ReprBytes function.
* \return rference to self. * \return rference to self.
*/ */
Registry& set_global_key(FGlobalKey f) { // NOLINT(*) Registry& set_repr_bytes(FReprBytes f) { // NOLINT(*)
CHECK_LT(type_index_, parent_->fglobal_key_.size()); CHECK_LT(type_index_, parent_->frepr_bytes_.size());
parent_->fglobal_key_[type_index_] = f; parent_->frepr_bytes_[type_index_] = f;
return *this; return *this;
} }
...@@ -365,7 +367,7 @@ ReflectionVTable::Register() { ...@@ -365,7 +367,7 @@ ReflectionVTable::Register() {
if (tindex >= fvisit_attrs_.size()) { if (tindex >= fvisit_attrs_.size()) {
fvisit_attrs_.resize(tindex + 1, nullptr); fvisit_attrs_.resize(tindex + 1, nullptr);
fcreate_.resize(tindex + 1, nullptr); fcreate_.resize(tindex + 1, nullptr);
fglobal_key_.resize(tindex + 1, nullptr); frepr_bytes_.resize(tindex + 1, nullptr);
fsequal_reduce_.resize(tindex + 1, nullptr); fsequal_reduce_.resize(tindex + 1, nullptr);
fshash_reduce_.resize(tindex + 1, nullptr); fshash_reduce_.resize(tindex + 1, nullptr);
} }
...@@ -392,12 +394,16 @@ VisitAttrs(Object* self, AttrVisitor* visitor) const { ...@@ -392,12 +394,16 @@ VisitAttrs(Object* self, AttrVisitor* visitor) const {
fvisit_attrs_[tindex](self, visitor); fvisit_attrs_[tindex](self, visitor);
} }
inline std::string ReflectionVTable::GetGlobalKey(Object* self) const { inline bool ReflectionVTable::GetReprBytes(const Object* self,
std::string* repr_bytes) const {
uint32_t tindex = self->type_index(); uint32_t tindex = self->type_index();
if (tindex < fglobal_key_.size() && fglobal_key_[tindex] != nullptr) { if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) {
return fglobal_key_[tindex](self); if (repr_bytes != nullptr) {
*repr_bytes = frepr_bytes_[tindex](self);
}
return true;
} else { } else {
return std::string(); return false;
} }
} }
......
...@@ -79,8 +79,16 @@ def create_updater_06_to_07(): ...@@ -79,8 +79,16 @@ def create_updater_06_to_07():
return item return item
return _convert return _convert
def _update_global_key(item, _):
item["repr_str"] = item["global_key"]
del item["global_key"]
return item
node_map = { node_map = {
# Base IR # Base IR
"SourceName": _update_global_key,
"EnvFunc": _update_global_key,
"relay.Op": _update_global_key,
"relay.TypeVar": _ftype_var, "relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var, "relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"), "relay.Type": _rename("Type"),
......
...@@ -69,7 +69,7 @@ TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc") ...@@ -69,7 +69,7 @@ TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc")
TVM_REGISTER_NODE_TYPE(EnvFuncNode) TVM_REGISTER_NODE_TYPE(EnvFuncNode)
.set_creator(CreateEnvNode) .set_creator(CreateEnvNode)
.set_global_key([](const Object* n) -> std::string { .set_repr_bytes([](const Object* n) -> std::string {
return static_cast<const EnvFuncNode*>(n)->name; return static_cast<const EnvFuncNode*>(n)->name;
}); });
......
...@@ -223,7 +223,7 @@ ObjectPtr<Object> CreateOp(const std::string& name) { ...@@ -223,7 +223,7 @@ ObjectPtr<Object> CreateOp(const std::string& name) {
TVM_REGISTER_NODE_TYPE(OpNode) TVM_REGISTER_NODE_TYPE(OpNode)
.set_creator(CreateOp) .set_creator(CreateOp)
.set_global_key([](const Object* n) { .set_repr_bytes([](const Object* n) {
return static_cast<const OpNode*>(n)->name; return static_cast<const OpNode*>(n)->name;
}); });
......
...@@ -56,7 +56,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) ...@@ -56,7 +56,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(SourceNameNode) TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(GetSourceNameNode) .set_creator(GetSourceNameNode)
.set_global_key([](const Object* n) { .set_repr_bytes([](const Object* n) {
return static_cast<const SourceNameNode*>(n)->name; return static_cast<const SourceNameNode*>(n)->name;
}); });
......
...@@ -48,7 +48,21 @@ struct StringObjTrait { ...@@ -48,7 +48,21 @@ struct StringObjTrait {
} }
}; };
TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait); struct RefToObjectPtr : public ObjectRef {
static ObjectPtr<Object> Get(const ObjectRef& ref) {
return GetDataPtr<Object>(ref);
}
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
.set_creator([](const std::string& bytes) {
return RefToObjectPtr::Get(runtime::String(bytes));
})
.set_repr_bytes([](const Object* n) -> std::string {
return GetRef<runtime::String>(
static_cast<const runtime::StringObj*>(n)).operator std::string();
});
struct ADTObjTrait { struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr; static constexpr const std::nullptr_t VisitAttrs = nullptr;
......
...@@ -178,13 +178,13 @@ ReflectionVTable* ReflectionVTable::Global() { ...@@ -178,13 +178,13 @@ ReflectionVTable* ReflectionVTable::Global() {
ObjectPtr<Object> ObjectPtr<Object>
ReflectionVTable::CreateInitObject(const std::string& type_key, ReflectionVTable::CreateInitObject(const std::string& type_key,
const std::string& global_key) const { const std::string& repr_bytes) const {
uint32_t tindex = Object::TypeKey2Index(type_key); uint32_t tindex = Object::TypeKey2Index(type_key);
if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) { if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: " << type_key LOG(FATAL) << "TypeError: " << type_key
<< " is not registered via TVM_REGISTER_NODE_TYPE"; << " is not registered via TVM_REGISTER_NODE_TYPE";
} }
return fcreate_[tindex](global_key); return fcreate_[tindex](repr_bytes);
} }
class NodeAttrSetter : public AttrVisitor { class NodeAttrSetter : public AttrVisitor {
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include <tvm/ir/attrs.h> #include <tvm/ir/attrs.h>
#include <string> #include <string>
#include <cctype>
#include <map> #include <map>
#include "../support/base64.h" #include "../support/base64.h"
...@@ -46,6 +47,26 @@ inline DataType String2Type(std::string s) { ...@@ -46,6 +47,26 @@ inline DataType String2Type(std::string s) {
return DataType(runtime::String2DLDataType(s)); return DataType(runtime::String2DLDataType(s));
} }
inline std::string Base64Decode(std::string s) {
dmlc::MemoryStringStream mstrm(&s);
support::Base64InStream b64strm(&mstrm);
std::string output;
b64strm.InitPosition();
dmlc::Stream* strm = &b64strm;
strm->Read(&output);
return output;
}
inline std::string Base64Encode(std::string s) {
std::string blob;
dmlc::MemoryStringStream mstrm(&blob);
support::Base64OutStream b64strm(&mstrm);
dmlc::Stream* strm = &b64strm;
strm->Write(s);
b64strm.Finish();
return blob;
}
// indexer to index all the nodes // indexer to index all the nodes
class NodeIndexer : public AttrVisitor { class NodeIndexer : public AttrVisitor {
public: public:
...@@ -103,7 +124,10 @@ class NodeIndexer : public AttrVisitor { ...@@ -103,7 +124,10 @@ class NodeIndexer : public AttrVisitor {
MakeIndex(const_cast<Object*>(kv.second.get())); MakeIndex(const_cast<Object*>(kv.second.get()));
} }
} else { } else {
reflection_->VisitAttrs(node, this); // if the node already have repr bytes, no need to visit Attrs.
if (!reflection_->GetReprBytes(node, nullptr)) {
reflection_->VisitAttrs(node, this);
}
} }
} }
}; };
...@@ -115,8 +139,8 @@ using AttrMap = std::map<std::string, std::string>; ...@@ -115,8 +139,8 @@ using AttrMap = std::map<std::string, std::string>;
struct JSONNode { struct JSONNode {
/*! \brief The type of key of the object. */ /*! \brief The type of key of the object. */
std::string type_key; std::string type_key;
/*! \brief The global key for global object. */ /*! \brief The str repr representation. */
std::string global_key; std::string repr_bytes;
/*! \brief the attributes */ /*! \brief the attributes */
AttrMap attrs; AttrMap attrs;
/*! \brief keys of a map. */ /*! \brief keys of a map. */
...@@ -127,8 +151,15 @@ struct JSONNode { ...@@ -127,8 +151,15 @@ struct JSONNode {
void Save(dmlc::JSONWriter *writer) const { void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject(); writer->BeginObject();
writer->WriteObjectKeyValue("type_key", type_key); writer->WriteObjectKeyValue("type_key", type_key);
if (global_key.size() != 0) { if (repr_bytes.size() != 0) {
writer->WriteObjectKeyValue("global_key", global_key); // choose to use str representation or base64, based on whether
// the byte representation is printable.
if (std::all_of(repr_bytes.begin(), repr_bytes.end(),
[](char ch) { return std::isprint(ch); })) {
writer->WriteObjectKeyValue("repr_str", repr_bytes);
} else {
writer->WriteObjectKeyValue("repr_b64", Base64Encode(repr_bytes));
}
} }
if (attrs.size() != 0) { if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs); writer->WriteObjectKeyValue("attrs", attrs);
...@@ -145,15 +176,24 @@ struct JSONNode { ...@@ -145,15 +176,24 @@ struct JSONNode {
void Load(dmlc::JSONReader *reader) { void Load(dmlc::JSONReader *reader) {
attrs.clear(); attrs.clear();
data.clear(); data.clear();
global_key.clear(); repr_bytes.clear();
type_key.clear(); type_key.clear();
std::string repr_b64, repr_str;
dmlc::JSONObjectReadHelper helper; dmlc::JSONObjectReadHelper helper;
helper.DeclareOptionalField("type_key", &type_key); helper.DeclareOptionalField("type_key", &type_key);
helper.DeclareOptionalField("global_key", &global_key); helper.DeclareOptionalField("repr_b64", &repr_b64);
helper.DeclareOptionalField("repr_str", &repr_str);
helper.DeclareOptionalField("attrs", &attrs); helper.DeclareOptionalField("attrs", &attrs);
helper.DeclareOptionalField("keys", &keys); helper.DeclareOptionalField("keys", &keys);
helper.DeclareOptionalField("data", &data); helper.DeclareOptionalField("data", &data);
helper.ReadAllFields(reader); helper.ReadAllFields(reader);
if (repr_str.size() != 0) {
CHECK_EQ(repr_b64.size(), 0U);
repr_bytes = std::move(repr_str);
} else if (repr_b64.size() != 0) {
repr_bytes = Base64Decode(repr_b64);
}
} }
}; };
...@@ -212,10 +252,8 @@ class JSONAttrGetter : public AttrVisitor { ...@@ -212,10 +252,8 @@ class JSONAttrGetter : public AttrVisitor {
return; return;
} }
node_->type_key = node->GetTypeKey(); node_->type_key = node->GetTypeKey();
node_->global_key = reflection_->GetGlobalKey(node); // do not need to print additional things once we have repr bytes.
// No need to recursively visit fields of global singleton if (reflection_->GetReprBytes(node, &(node_->repr_bytes))) return;
// They are registered via the environment.
if (node_->global_key.length() != 0) return;
// populates the fields. // populates the fields.
node_->attrs.clear(); node_->attrs.clear();
...@@ -434,7 +472,7 @@ ObjectRef LoadJSON(std::string json_str) { ...@@ -434,7 +472,7 @@ ObjectRef LoadJSON(std::string json_str) {
for (const JSONNode& jnode : jgraph.nodes) { for (const JSONNode& jnode : jgraph.nodes) {
if (jnode.type_key.length() != 0) { if (jnode.type_key.length() != 0) {
ObjectPtr<Object> node = ObjectPtr<Object> node =
reflection->CreateInitObject(jnode.type_key, jnode.global_key); reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes);
nodes.emplace_back(node); nodes.emplace_back(node);
} else { } else {
nodes.emplace_back(ObjectPtr<Object>()); nodes.emplace_back(ObjectPtr<Object>());
...@@ -447,9 +485,12 @@ ObjectRef LoadJSON(std::string json_str) { ...@@ -447,9 +485,12 @@ ObjectRef LoadJSON(std::string json_str) {
for (size_t i = 0; i < nodes.size(); ++i) { for (size_t i = 0; i < nodes.size(); ++i) {
setter.node_ = &jgraph.nodes[i]; setter.node_ = &jgraph.nodes[i];
// do not need to recover content of global singleton object // Skip the nodes that has an repr bytes representation.
// they are registered via the environment // NOTE: the second condition is used to guard the case
if (setter.node_->global_key.length() == 0) { // where the repr bytes itself is an empty string "".
if (setter.node_->repr_bytes.length() == 0 &&
nodes[i] != nullptr &&
!reflection->GetReprBytes(nodes[i].get(), nullptr)) {
setter.Set(nodes[i].get()); setter.Set(nodes[i].get());
} }
} }
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
# under the License. # under the License.
import tvm import tvm
from tvm import relay
from tvm import te from tvm import te
import json import json
...@@ -108,6 +109,22 @@ def test_global_var(): ...@@ -108,6 +109,22 @@ def test_global_var():
assert isinstance(tvar, tvm.ir.GlobalVar) assert isinstance(tvar, tvm.ir.GlobalVar)
def test_op():
nodes = [
{"type_key": ""},
{"type_key": "relay.Op",
"global_key": "nn.conv2d"}
]
data = {
"root" : 1,
"nodes": nodes,
"attrs": {"tvm_version": "0.6.0"},
"b64ndarrays": [],
}
op = tvm.ir.load_json(json.dumps(data))
assert op == relay.op.get("nn.conv2d")
def test_tir_var(): def test_tir_var():
nodes = [ nodes = [
{"type_key": ""}, {"type_key": ""},
...@@ -132,6 +149,7 @@ def test_tir_var(): ...@@ -132,6 +149,7 @@ def test_tir_var():
if __name__ == "__main__": if __name__ == "__main__":
test_op()
test_type_var() test_type_var()
test_incomplete_type() test_incomplete_type()
test_func_tuple_type() test_func_tuple_type()
......
...@@ -89,7 +89,20 @@ def test_env_func(): ...@@ -89,7 +89,20 @@ def test_env_func():
assert x.func(10) == 11 assert x.func(10) == 11
def test_string():
# non printable str, need to store by b64
s1 = tvm.runtime.String("xy\x01z")
s2 = tvm.ir.load_json(tvm.ir.save_json(s1))
tvm.ir.assert_structural_equal(s1, s2)
# printable str, need to store by repr_str
s1 = tvm.runtime.String("xyz")
s2 = tvm.ir.load_json(tvm.ir.save_json(s1))
tvm.ir.assert_structural_equal(s1, s2)
if __name__ == "__main__": if __name__ == "__main__":
test_string()
test_env_func() test_env_func()
test_make_node() test_make_node()
test_make_smap() test_make_smap()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment