Unverified Commit e0af5c20 by Tianqi Chen Committed by GitHub

[RELAY] TextPrinter: Use Map Format (#2553)

parent e2970b22
......@@ -42,11 +42,11 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO
* we support a meta-data section in the text format.
* We allow the text format to refer to a node in the meta-data section.
*
* The meta-data section is a json serialized string of an Array<NodeRef>.
* The meta-data section is a json serialized string of an Map<string, Array<NodeRef>>.
* Each element in the meta-data section can be referenced by the text format.
* Each meta data node is printed in the following format.
*
* meta.<type-key-of-node>(<index-in-meta-section>)
* meta[type-key-of-node>][<index-in-meta-section>]
*
* Specifically, consider the following IR(constructed by python).
*
......@@ -63,7 +63,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO
*
* \code
*
* fn (%x: Tensor[(meta.Variable(id=0),), float32]) {
* fn (%x: Tensor[(meta[Variable][0],), float32]) {
* %x
* }
* # Meta data section is a json-serialized string
......@@ -74,7 +74,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO
*
* Note that we store tvm.var("n") in the meta data section.
* Since it is stored in the index-0 in the meta-data section,
* we print it as meta.Variable(0).
* we print it as meta[Variable][0].
*
* The text parser can recover this object by loading from the corresponding
* location in the meta data section.
......@@ -91,18 +91,18 @@ class TextMetaDataContext {
* \return A string representation of the meta node.
*/
std::string GetMetaNode(const NodeRef& node) {
std::ostringstream os;
auto it = meta_index_.find(node);
int64_t index;
if (it != meta_index_.end()) {
index = it->second;
} else {
index = static_cast<int64_t>(meta_data_.size());
meta_data_.push_back(node);
meta_index_[node] = index;
auto it = meta_repr_.find(node);
if (it != meta_repr_.end()) {
return it->second;
}
os << "meta." << node->type_key() << "(id=" << index << ")";
return os.str();
Array<NodeRef>& mvector =
meta_data_[node->type_key()];
int64_t index = static_cast<int64_t>(mvector.size());
mvector.push_back(node);
std::ostringstream os;
os << "meta[" << node->type_key() << "][" << index << "]";
meta_repr_[node] = os.str();
return meta_repr_[node];
}
/*!
* \brief Get the metadata section in json format.
......@@ -110,7 +110,8 @@ class TextMetaDataContext {
*/
std::string GetMetaSection() const {
if (meta_data_.size() == 0) return std::string();
return SaveJSON(Array<NodeRef>(meta_data_));
return SaveJSON(Map<std::string, NodeRef>(
meta_data_.begin(), meta_data_.end()));
}
/*! \return whether the meta data context is empty. */
......@@ -120,9 +121,9 @@ class TextMetaDataContext {
private:
/*! \brief additional metadata stored in TVM json format */
std::vector<NodeRef> meta_data_;
/*! \brief map from meta data into its index */
std::unordered_map<NodeRef, int64_t, NodeHash, NodeEqual> meta_index_;
std::unordered_map<std::string, Array<NodeRef> > meta_data_;
/*! \brief map from meta data into its string representation */
std::unordered_map<NodeRef, std::string, NodeHash, NodeEqual> meta_repr_;
};
class TextPrinter :
......
......@@ -48,11 +48,11 @@ def test_meta_data():
f = relay.Function([x, w], z)
text = f.astext()
assert "channels=2" in text
assert "meta.Variable(id=0)" in text
assert "meta[Variable][0]" in text
show(text)
text = relay.const([1,2,3]).astext()
assert "meta.relay.Constant(id=0)" in text
assert "meta[relay.Constant][0]" in text
show(text)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment