Unverified Commit 029388f5 by Tianqi Chen Committed by GitHub

[NODE] General serialization of leaf objects into bytes. (#5299)

This PR refactors the serialization mechanism to support general
serialization of leaf objects into bytes.

The new feature supersedes the original GetGlobalKey feature for singletons.
Added serialization support for runtime::String.
parent 7d670b04
......@@ -98,17 +98,17 @@ class ReflectionVTable {
typedef void (*FSHashReduce)(const Object* self, SHashReducer hash_reduce);
/*!
* \brief creator function.
* \param global_key Key that identifies a global single object.
* If this is not empty then FGlobalKey must be defined for the object.
* \param repr_bytes Repr bytes to create the object.
* If this is not empty then FReprBytes must be defined for the object.
* \return The created object.
*/
typedef ObjectPtr<Object> (*FCreate)(const std::string& global_key);
typedef ObjectPtr<Object> (*FCreate)(const std::string& repr_bytes);
/*!
* \brief Global key function, only needed by global objects.
* \brief Function to get a byte representation that can be used to recover the object.
* \param node The node pointer.
* \return node The global key to the node.
* \return bytes The bytes that can be used to recover the object.
*/
typedef std::string (*FGlobalKey)(const Object* self);
typedef std::string (*FReprBytes)(const Object* self);
/*!
* \brief Dispatch the VisitAttrs function.
* \param self The pointer to the object.
......@@ -116,11 +116,13 @@ class ReflectionVTable {
*/
inline void VisitAttrs(Object* self, AttrVisitor* visitor) const;
/*!
* \brief Get global key of the object, if any.
* \brief Get repr bytes if any.
* \param self The pointer to the object.
* \return the global key if object has one, otherwise return empty string.
* \param repr_bytes The output repr bytes, can be null, in which case the function
* simply queries if the ReprBytes function exists for the type.
* \return Whether repr bytes exists
*/
inline std::string GetGlobalKey(Object* self) const;
inline bool GetReprBytes(const Object* self, std::string* repr_bytes) const;
/*!
* \brief Dispatch the SEqualReduce function.
* \param self The pointer to the object.
......@@ -141,10 +143,10 @@ class ReflectionVTable {
* by type_key and repr bytes.
*
* \param type_key The type key of the object.
* \param global_key A global key that can be used to uniquely identify the object if any.
* \param repr_bytes Bytes representation of the object if any.
*/
TVM_DLL ObjectPtr<Object> CreateInitObject(const std::string& type_key,
const std::string& global_key = "") const;
const std::string& repr_bytes = "") const;
/*!
* \brief Get a field object by the attr name.
* \param self The pointer to the object.
......@@ -176,8 +178,8 @@ class ReflectionVTable {
std::vector<FSHashReduce> fshash_reduce_;
/*! \brief Creation function. */
std::vector<FCreate> fcreate_;
/*! \brief Global key function. */
std::vector<FGlobalKey> fglobal_key_;
/*! \brief ReprBytes function. */
std::vector<FReprBytes> frepr_bytes_;
};
/*! \brief Registry of a reflection table. */
......@@ -196,13 +198,13 @@ class ReflectionVTable::Registry {
return *this;
}
/*!
* \brief Set global_key function.
* \param f The creator function.
* \brief Set bytes repr function.
* \param f The ReprBytes function.
* \return reference to self.
*/
Registry& set_global_key(FGlobalKey f) { // NOLINT(*)
CHECK_LT(type_index_, parent_->fglobal_key_.size());
parent_->fglobal_key_[type_index_] = f;
Registry& set_repr_bytes(FReprBytes f) { // NOLINT(*)
CHECK_LT(type_index_, parent_->frepr_bytes_.size());
parent_->frepr_bytes_[type_index_] = f;
return *this;
}
......@@ -365,7 +367,7 @@ ReflectionVTable::Register() {
if (tindex >= fvisit_attrs_.size()) {
fvisit_attrs_.resize(tindex + 1, nullptr);
fcreate_.resize(tindex + 1, nullptr);
fglobal_key_.resize(tindex + 1, nullptr);
frepr_bytes_.resize(tindex + 1, nullptr);
fsequal_reduce_.resize(tindex + 1, nullptr);
fshash_reduce_.resize(tindex + 1, nullptr);
}
......@@ -392,12 +394,16 @@ VisitAttrs(Object* self, AttrVisitor* visitor) const {
fvisit_attrs_[tindex](self, visitor);
}
inline std::string ReflectionVTable::GetGlobalKey(Object* self) const {
inline bool ReflectionVTable::GetReprBytes(const Object* self,
std::string* repr_bytes) const {
uint32_t tindex = self->type_index();
if (tindex < fglobal_key_.size() && fglobal_key_[tindex] != nullptr) {
return fglobal_key_[tindex](self);
if (tindex < frepr_bytes_.size() && frepr_bytes_[tindex] != nullptr) {
if (repr_bytes != nullptr) {
*repr_bytes = frepr_bytes_[tindex](self);
}
return true;
} else {
return std::string();
return false;
}
}
......
......@@ -79,8 +79,16 @@ def create_updater_06_to_07():
return item
return _convert
def _update_global_key(item, _):
item["repr_str"] = item["global_key"]
del item["global_key"]
return item
node_map = {
# Base IR
"SourceName": _update_global_key,
"EnvFunc": _update_global_key,
"relay.Op": _update_global_key,
"relay.TypeVar": _ftype_var,
"relay.GlobalTypeVar": _ftype_var,
"relay.Type": _rename("Type"),
......
......@@ -69,7 +69,7 @@ TVM_REGISTER_GLOBAL("ir.EnvFuncGetPackedFunc")
TVM_REGISTER_NODE_TYPE(EnvFuncNode)
.set_creator(CreateEnvNode)
.set_global_key([](const Object* n) -> std::string {
.set_repr_bytes([](const Object* n) -> std::string {
return static_cast<const EnvFuncNode*>(n)->name;
});
......
......@@ -223,7 +223,7 @@ ObjectPtr<Object> CreateOp(const std::string& name) {
TVM_REGISTER_NODE_TYPE(OpNode)
.set_creator(CreateOp)
.set_global_key([](const Object* n) {
.set_repr_bytes([](const Object* n) {
return static_cast<const OpNode*>(n)->name;
});
......
......@@ -56,7 +56,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
TVM_REGISTER_NODE_TYPE(SourceNameNode)
.set_creator(GetSourceNameNode)
.set_global_key([](const Object* n) {
.set_repr_bytes([](const Object* n) {
return static_cast<const SourceNameNode*>(n)->name;
});
......
......@@ -48,7 +48,21 @@ struct StringObjTrait {
}
};
TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait);
/*!
 * \brief Helper that extracts the underlying ObjectPtr<Object> from an ObjectRef.
 *
 * NOTE(review): this derives from ObjectRef presumably so that it can call a
 * GetDataPtr helper that is not publicly accessible -- confirm against the
 * ObjectRef declaration.
 */
struct RefToObjectPtr : public ObjectRef {
/*!
 * \brief Get the object pointer held by a reference.
 * \param ref The object reference.
 * \return The result of GetDataPtr<Object>(ref).
 */
static ObjectPtr<Object> Get(const ObjectRef& ref) {
return GetDataPtr<Object>(ref);
}
};
// Register runtime::StringObj in the reflection vtable together with a creator
// and a repr-bytes function, so the object can be rebuilt from its bytes.
TVM_REGISTER_REFLECTION_VTABLE(runtime::StringObj, StringObjTrait)
// Creator: construct a runtime::String from the given bytes and return its object pointer.
.set_creator([](const std::string& bytes) {
return RefToObjectPtr::Get(runtime::String(bytes));
})
// ReprBytes: convert the StringObj back into a std::string via runtime::String.
.set_repr_bytes([](const Object* n) -> std::string {
return GetRef<runtime::String>(
static_cast<const runtime::StringObj*>(n)).operator std::string();
});
struct ADTObjTrait {
static constexpr const std::nullptr_t VisitAttrs = nullptr;
......
......@@ -178,13 +178,13 @@ ReflectionVTable* ReflectionVTable::Global() {
ObjectPtr<Object>
ReflectionVTable::CreateInitObject(const std::string& type_key,
const std::string& global_key) const {
const std::string& repr_bytes) const {
uint32_t tindex = Object::TypeKey2Index(type_key);
if (tindex >= fcreate_.size() || fcreate_[tindex] == nullptr) {
LOG(FATAL) << "TypeError: " << type_key
<< " is not registered via TVM_REGISTER_NODE_TYPE";
}
return fcreate_[tindex](global_key);
return fcreate_[tindex](repr_bytes);
}
class NodeAttrSetter : public AttrVisitor {
......
......@@ -32,6 +32,7 @@
#include <tvm/ir/attrs.h>
#include <string>
#include <cctype>
#include <map>
#include "../support/base64.h"
......@@ -46,6 +47,26 @@ inline DataType String2Type(std::string s) {
return DataType(runtime::String2DLDataType(s));
}
inline std::string Base64Decode(std::string s) {
dmlc::MemoryStringStream mstrm(&s);
support::Base64InStream b64strm(&mstrm);
std::string output;
b64strm.InitPosition();
dmlc::Stream* strm = &b64strm;
strm->Read(&output);
return output;
}
inline std::string Base64Encode(std::string s) {
std::string blob;
dmlc::MemoryStringStream mstrm(&blob);
support::Base64OutStream b64strm(&mstrm);
dmlc::Stream* strm = &b64strm;
strm->Write(s);
b64strm.Finish();
return blob;
}
// indexer to index all the nodes
class NodeIndexer : public AttrVisitor {
public:
......@@ -103,7 +124,10 @@ class NodeIndexer : public AttrVisitor {
MakeIndex(const_cast<Object*>(kv.second.get()));
}
} else {
reflection_->VisitAttrs(node, this);
// if the node already has repr bytes, there is no need to visit its attrs.
if (!reflection_->GetReprBytes(node, nullptr)) {
reflection_->VisitAttrs(node, this);
}
}
}
};
......@@ -115,8 +139,8 @@ using AttrMap = std::map<std::string, std::string>;
struct JSONNode {
/*! \brief The type of key of the object. */
std::string type_key;
/*! \brief The global key for global object. */
std::string global_key;
/*! \brief The repr bytes of the object, if any. */
std::string repr_bytes;
/*! \brief the attributes */
AttrMap attrs;
/*! \brief keys of a map. */
......@@ -127,8 +151,15 @@ struct JSONNode {
void Save(dmlc::JSONWriter *writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("type_key", type_key);
if (global_key.size() != 0) {
writer->WriteObjectKeyValue("global_key", global_key);
if (repr_bytes.size() != 0) {
  // Emit the bytes verbatim under "repr_str" when every byte is printable,
  // otherwise emit them base64-encoded under "repr_b64".
  //
  // std::isprint has undefined behavior when its argument is negative (and
  // not EOF), which happens for bytes >= 0x80 on platforms where plain char
  // is signed. Convert each byte to unsigned char before classifying it.
  if (std::all_of(repr_bytes.begin(), repr_bytes.end(),
                  [](char ch) {
                    return std::isprint(static_cast<unsigned char>(ch)) != 0;
                  })) {
    writer->WriteObjectKeyValue("repr_str", repr_bytes);
  } else {
    writer->WriteObjectKeyValue("repr_b64", Base64Encode(repr_bytes));
  }
}
if (attrs.size() != 0) {
writer->WriteObjectKeyValue("attrs", attrs);
......@@ -145,15 +176,24 @@ struct JSONNode {
void Load(dmlc::JSONReader *reader) {
attrs.clear();
data.clear();
global_key.clear();
repr_bytes.clear();
type_key.clear();
std::string repr_b64, repr_str;
dmlc::JSONObjectReadHelper helper;
helper.DeclareOptionalField("type_key", &type_key);
helper.DeclareOptionalField("global_key", &global_key);
helper.DeclareOptionalField("repr_b64", &repr_b64);
helper.DeclareOptionalField("repr_str", &repr_str);
helper.DeclareOptionalField("attrs", &attrs);
helper.DeclareOptionalField("keys", &keys);
helper.DeclareOptionalField("data", &data);
helper.ReadAllFields(reader);
if (repr_str.size() != 0) {
CHECK_EQ(repr_b64.size(), 0U);
repr_bytes = std::move(repr_str);
} else if (repr_b64.size() != 0) {
repr_bytes = Base64Decode(repr_b64);
}
}
};
......@@ -212,10 +252,8 @@ class JSONAttrGetter : public AttrVisitor {
return;
}
node_->type_key = node->GetTypeKey();
node_->global_key = reflection_->GetGlobalKey(node);
// No need to recursively visit fields of global singleton
// They are registered via the environment.
if (node_->global_key.length() != 0) return;
// do not need to print additional things once we have repr bytes.
if (reflection_->GetReprBytes(node, &(node_->repr_bytes))) return;
// populates the fields.
node_->attrs.clear();
......@@ -434,7 +472,7 @@ ObjectRef LoadJSON(std::string json_str) {
for (const JSONNode& jnode : jgraph.nodes) {
if (jnode.type_key.length() != 0) {
ObjectPtr<Object> node =
reflection->CreateInitObject(jnode.type_key, jnode.global_key);
reflection->CreateInitObject(jnode.type_key, jnode.repr_bytes);
nodes.emplace_back(node);
} else {
nodes.emplace_back(ObjectPtr<Object>());
......@@ -447,9 +485,12 @@ ObjectRef LoadJSON(std::string json_str) {
for (size_t i = 0; i < nodes.size(); ++i) {
setter.node_ = &jgraph.nodes[i];
// do not need to recover content of global singleton object
// they are registered via the environment
if (setter.node_->global_key.length() == 0) {
// Skip the nodes that have a repr bytes representation.
// NOTE: the second condition is used to guard the case
// where the repr bytes itself is an empty string "".
if (setter.node_->repr_bytes.length() == 0 &&
nodes[i] != nullptr &&
!reflection->GetReprBytes(nodes[i].get(), nullptr)) {
setter.Set(nodes[i].get());
}
}
......
......@@ -16,6 +16,7 @@
# under the License.
import tvm
from tvm import relay
from tvm import te
import json
......@@ -108,6 +109,22 @@ def test_global_var():
assert isinstance(tvar, tvm.ir.GlobalVar)
def test_op():
    """Load a 0.6.0-format graph whose relay.Op node still uses "global_key"."""
    graph = {
        "root" : 1,
        "nodes": [
            {"type_key": ""},
            {"type_key": "relay.Op", "global_key": "nn.conv2d"},
        ],
        "attrs": {"tvm_version": "0.6.0"},
        "b64ndarrays": [],
    }
    loaded = tvm.ir.load_json(json.dumps(graph))
    # The loaded node must compare equal to the registered operator.
    assert loaded == relay.op.get("nn.conv2d")
def test_tir_var():
nodes = [
{"type_key": ""},
......@@ -132,6 +149,7 @@ def test_tir_var():
if __name__ == "__main__":
test_op()
test_type_var()
test_incomplete_type()
test_func_tuple_type()
......
......@@ -89,7 +89,20 @@ def test_env_func():
assert x.func(10) == 11
def test_string():
    """Round-trip runtime.String values through save_json / load_json."""
    samples = [
        "xy\x01z",  # non printable str, need to store by b64
        "xyz",  # printable str, need to store by repr_str
    ]
    for text in samples:
        original = tvm.runtime.String(text)
        restored = tvm.ir.load_json(tvm.ir.save_json(original))
        tvm.ir.assert_structural_equal(original, restored)
if __name__ == "__main__":
test_string()
test_env_func()
test_make_node()
test_make_smap()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment