Commit fed1c08e by Josh Pollock Committed by Tianqi Chen

[Relay][Text Format] Fix Pretty Printing Annotations (#3041)

parent cdc9e85c
...@@ -156,14 +156,19 @@ class PrettyPrinter : ...@@ -156,14 +156,19 @@ class PrettyPrinter :
*/ */
Doc PrintOptionalInfo(const Expr& expr) { Doc PrintOptionalInfo(const Expr& expr) {
Doc doc; Doc doc;
// additional information in comment. // default annotations
if (annotate_ != nullptr) { if (annotate_ == nullptr) {
return doc << " /* " << annotate_(expr) << " */"; if ((expr.as<ConstantNode>() || expr.as<CallNode>()) && expr->checked_type_.defined()) {
} else if (expr->checked_type_.defined()) { doc << " /* ty=" << Print(expr->checked_type()) << " */";
return doc << " /* ty=" << Print(expr->checked_type()) << " */"; }
} else { } else {
return doc; std::string annotated_expr = annotate_(expr);
if (annotated_expr != "") {
doc << annotated_expr;
}
} }
return doc;
} }
// indent a new body // indent a new body
...@@ -361,9 +366,7 @@ class PrettyPrinter : ...@@ -361,9 +366,7 @@ class PrettyPrinter :
printed_expr = VisitExpr(expr); printed_expr = VisitExpr(expr);
} }
if (expr.as<CallNode>()) { printed_expr << PrintOptionalInfo(expr);
printed_expr << PrintOptionalInfo(expr);
}
// add expr to doc // add expr to doc
if (expr.as<VarNode>()) { if (expr.as<VarNode>()) {
...@@ -409,8 +412,7 @@ class PrettyPrinter : ...@@ -409,8 +412,7 @@ class PrettyPrinter :
} }
// default fall-back, record it as meta node. // default fall-back, record it as meta node.
Doc doc; Doc doc;
return doc << Print(GetRef<NodeRef>(op), true) return doc << Print(GetRef<NodeRef>(op), true);
<< PrintOptionalInfo(GetRef<Expr>(op));
} }
Doc VisitExpr_(const TupleNode* op) final { Doc VisitExpr_(const TupleNode* op) final {
......
...@@ -52,7 +52,7 @@ def test_env(): ...@@ -52,7 +52,7 @@ def test_env():
assert "def @myf" in str(env) assert "def @myf" in str(env)
assert "add(%0, %0) /* ty=float32 */" in text assert "add(%0, %0) /* ty=float32 */" in text
assert "add(%0, %0) /* ty=float32 */" in str(env) assert "add(%0, %0) /* ty=float32 */" in str(env)
show(env.astext(annotate=lambda x: str(x.checked_type.dtype))) show(env.astext(annotate=lambda x: str(x.checked_type.dtype) if type(x) == relay.Call else ""))
show(text) show(text)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment