Unverified Commit e0af5c20 by Tianqi Chen Committed by GitHub

[RELAY] TextPrinter: Use Map Format (#2553)

parent e2970b22
...@@ -42,11 +42,11 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO ...@@ -42,11 +42,11 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO
* we support a meta-data section in the text format. * we support a meta-data section in the text format.
* We allow the text format to refer to a node in the meta-data section. * We allow the text format to refer to a node in the meta-data section.
* *
* The meta-data section is a json serialized string of an Array<NodeRef>. * The meta-data section is a json serialized string of an Map<string, Array<NodeRef>>.
* Each element in the meta-data section can be referenced by the text format. * Each element in the meta-data section can be referenced by the text format.
* Each meta data node is printed in the following format. * Each meta data node is printed in the following format.
* *
* meta.<type-key-of-node>(<index-in-meta-section>) * meta[type-key-of-node>][<index-in-meta-section>]
* *
* Specifically, consider the following IR(constructed by python). * Specifically, consider the following IR(constructed by python).
* *
...@@ -63,7 +63,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO ...@@ -63,7 +63,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO
* *
* \code * \code
* *
* fn (%x: Tensor[(meta.Variable(id=0),), float32]) { * fn (%x: Tensor[(meta[Variable][0],), float32]) {
* %x * %x
* } * }
* # Meta data section is a json-serialized string * # Meta data section is a json-serialized string
...@@ -74,7 +74,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO ...@@ -74,7 +74,7 @@ inline std::ostream& operator<<(std::ostream& os, const TextValue& val) { // NO
* *
* Note that we store tvm.var("n") in the meta data section. * Note that we store tvm.var("n") in the meta data section.
* Since it is stored in the index-0 in the meta-data section, * Since it is stored in the index-0 in the meta-data section,
* we print it as meta.Variable(0). * we print it as meta[Variable][0].
* *
* The text parser can recover this object by loading from the corresponding * The text parser can recover this object by loading from the corresponding
* location in the meta data section. * location in the meta data section.
...@@ -91,18 +91,18 @@ class TextMetaDataContext { ...@@ -91,18 +91,18 @@ class TextMetaDataContext {
* \return A string representation of the meta node. * \return A string representation of the meta node.
*/ */
std::string GetMetaNode(const NodeRef& node) { std::string GetMetaNode(const NodeRef& node) {
auto it = meta_repr_.find(node);
if (it != meta_repr_.end()) {
return it->second;
}
Array<NodeRef>& mvector =
meta_data_[node->type_key()];
int64_t index = static_cast<int64_t>(mvector.size());
mvector.push_back(node);
std::ostringstream os; std::ostringstream os;
auto it = meta_index_.find(node); os << "meta[" << node->type_key() << "][" << index << "]";
int64_t index; meta_repr_[node] = os.str();
if (it != meta_index_.end()) { return meta_repr_[node];
index = it->second;
} else {
index = static_cast<int64_t>(meta_data_.size());
meta_data_.push_back(node);
meta_index_[node] = index;
}
os << "meta." << node->type_key() << "(id=" << index << ")";
return os.str();
} }
/*! /*!
* \brief Get the metadata section in json format. * \brief Get the metadata section in json format.
...@@ -110,7 +110,8 @@ class TextMetaDataContext { ...@@ -110,7 +110,8 @@ class TextMetaDataContext {
*/ */
std::string GetMetaSection() const { std::string GetMetaSection() const {
if (meta_data_.size() == 0) return std::string(); if (meta_data_.size() == 0) return std::string();
return SaveJSON(Array<NodeRef>(meta_data_)); return SaveJSON(Map<std::string, NodeRef>(
meta_data_.begin(), meta_data_.end()));
} }
/*! \return whether the meta data context is empty. */ /*! \return whether the meta data context is empty. */
...@@ -120,9 +121,9 @@ class TextMetaDataContext { ...@@ -120,9 +121,9 @@ class TextMetaDataContext {
private: private:
/*! \brief additional metadata stored in TVM json format */ /*! \brief additional metadata stored in TVM json format */
std::vector<NodeRef> meta_data_; std::unordered_map<std::string, Array<NodeRef> > meta_data_;
/*! \brief map from meta data into its index */ /*! \brief map from meta data into its string representation */
std::unordered_map<NodeRef, int64_t, NodeHash, NodeEqual> meta_index_; std::unordered_map<NodeRef, std::string, NodeHash, NodeEqual> meta_repr_;
}; };
class TextPrinter : class TextPrinter :
......
...@@ -48,11 +48,11 @@ def test_meta_data(): ...@@ -48,11 +48,11 @@ def test_meta_data():
f = relay.Function([x, w], z) f = relay.Function([x, w], z)
text = f.astext() text = f.astext()
assert "channels=2" in text assert "channels=2" in text
assert "meta.Variable(id=0)" in text assert "meta[Variable][0]" in text
show(text) show(text)
text = relay.const([1,2,3]).astext() text = relay.const([1,2,3]).astext()
assert "meta.relay.Constant(id=0)" in text assert "meta[relay.Constant][0]" in text
show(text) show(text)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment