Commit fed1c08e by Josh Pollock Committed by Tianqi Chen

[Relay][Text Format] Fix Pretty Printing Annotations (#3041)

parent cdc9e85c
......@@ -156,14 +156,19 @@ class PrettyPrinter :
*/
Doc PrintOptionalInfo(const Expr& expr) {
Doc doc;
// additional information in comment.
if (annotate_ != nullptr) {
return doc << " /* " << annotate_(expr) << " */";
} else if (expr->checked_type_.defined()) {
return doc << " /* ty=" << Print(expr->checked_type()) << " */";
// default annotations
if (annotate_ == nullptr) {
if ((expr.as<ConstantNode>() || expr.as<CallNode>()) && expr->checked_type_.defined()) {
doc << " /* ty=" << Print(expr->checked_type()) << " */";
}
} else {
return doc;
std::string annotated_expr = annotate_(expr);
if (annotated_expr != "") {
doc << annotated_expr;
}
}
return doc;
}
// indent a new body
......@@ -361,9 +366,7 @@ class PrettyPrinter :
printed_expr = VisitExpr(expr);
}
if (expr.as<CallNode>()) {
printed_expr << PrintOptionalInfo(expr);
}
printed_expr << PrintOptionalInfo(expr);
// add expr to doc
if (expr.as<VarNode>()) {
......@@ -409,8 +412,7 @@ class PrettyPrinter :
}
// default fall-back, record it as meta node.
Doc doc;
return doc << Print(GetRef<NodeRef>(op), true)
<< PrintOptionalInfo(GetRef<Expr>(op));
return doc << Print(GetRef<NodeRef>(op), true);
}
Doc VisitExpr_(const TupleNode* op) final {
......
......@@ -52,7 +52,7 @@ def test_env():
assert "def @myf" in str(env)
assert "add(%0, %0) /* ty=float32 */" in text
assert "add(%0, %0) /* ty=float32 */" in str(env)
show(env.astext(annotate=lambda x: str(x.checked_type.dtype)))
show(env.astext(annotate=lambda x: str(x.checked_type.dtype) if type(x) == relay.Call else ""))
show(text)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment