Commit 2973f8a6 by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] parser/pretty printer roundtripping (#3536)

parent e5efc632
......@@ -224,6 +224,20 @@ def alpha_equal(lhs, rhs):
return bool(_make._alpha_equal(lhs, rhs))
def assert_alpha_equal(lhs, rhs):
"""Assert that two Relay expr is structurally equivalent. (alpha equivalence).
Parameters
----------
lhs : tvm.relay.Expr
One of the input expressions.
rhs : tvm.relay.Expr
One of the input expressions.
"""
_make._assert_alpha_equal(lhs, rhs)
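A minimal usage sketch of the new assertion helper (assuming it is imported from tvm.relay.analysis, as the updated tests below do):

    from tvm import relay
    from tvm.relay.analysis import assert_alpha_equal

    # Two functions that differ only in the names of their bound variables
    # are alpha-equivalent, so this does not raise.
    x = relay.var("x")
    y = relay.var("y")
    assert_alpha_equal(relay.Function([x], x), relay.Function([y], y))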
def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
......@@ -246,6 +260,23 @@ def graph_equal(lhs, rhs):
return bool(_make._graph_equal(lhs, rhs))
def assert_graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.
Parameters
----------
lhs : tvm.relay.Expr
One of the input expressions.
rhs : tvm.relay.Expr
One of the input expressions.
"""
_make._assert_graph_equal(lhs, rhs)
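And a sketch of the graph (data-flow) variant; unlike alpha equality, free variables on the two sides are mapped to each other, so differently named sources still compare equal:

    from tvm import relay
    from tvm.relay.analysis import assert_graph_equal

    a = relay.var("a")
    b = relay.var("b")
    # add(%a, %a) and add(%b, %b) have the same data-flow structure.
    assert_graph_equal(relay.add(a, a), relay.add(b, b))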
def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
......
......@@ -17,15 +17,20 @@
* under the License.
*/
// list = *, seq = ?
grammar Relay;
SEMVER: 'v0.0.3' ;
// Lexing
// comments
WS : [ \t\n\r]+ -> skip ;
LINE_COMMENT : '//' .*? '\n' -> skip ;
COMMENT : '/*' .*? '*/' -> skip ;
COMMENT : '/*' (COMMENT|.)*? '*/' -> skip;
WS : [ \t\n\r]+ -> skip;
LINE_COMMENT : '//' .*? '\n' -> skip;
fragment ESCAPED_QUOTE : '\\"';
QUOTED_STRING : '"' ( ESCAPED_QUOTE | ~('\n'|'\r') )*? '"';
// operators
MUL: '*' ;
......@@ -39,18 +44,18 @@ GE: '>=' ;
EQ: '==' ;
NE: '!=' ;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
MUT: 'mut' ;
BOOL_LIT
: 'True'
| 'False'
;
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)*;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
DATATYPE : 'int64';
// non-negative floats
fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4
......@@ -60,109 +65,99 @@ FLOAT : PREFLOAT 'f';
NAT: DIGIT+ ;
fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...]
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ;
fragment LETTER: [a-zA-Z] ;
fragment DIGIT: [0-9] ;
fragment LETTER: [a-zA-Z];
fragment DIGIT: [0-9];
METADATA: 'METADATA:' .*;
// Parsing
// A Relay program is a list of global definitions or an expression.
prog: SEMVER (defn* | expr) EOF ;
prog: SEMVER (defn* | expr) METADATA? EOF ;
// option: 'set' ident BOOL_LIT ;
exprList: (expr (',' expr)*)?;
callList
: exprList # callNoAttr
| (expr ',')* attrSeq # callWithAttr
;
expr
// operators
: '(' expr ')' # parens
: '(' expr ')' # paren
| '{' expr '}' # paren
// function application
| expr '(' (expr (',' expr)*)? ')' # call
| expr '(' callList ')' # call
| '-' expr # neg
| expr op=('*'|'/') expr # binOp
| expr op=('+'|'-') expr # binOp
| expr op=('<'|'>'|'<='|'>=') expr # binOp
| expr op=('=='|'!=') expr # binOp
// function definition
| func # funcExpr
// tuples and tensors
| '(' ')' # tuple
| '(' expr ',' ')' # tuple
| '(' expr (',' expr)+ ')' # tuple
| expr '.' NAT # projection
| '[' (expr (',' expr)*)? ']' # tensor
| 'if' '(' expr ')' body 'else' body # ifElse
// sequencing
| 'let' MUT? var '=' expr ';' expr # let
| 'let' MUT? var '=' '{' expr '}' ';' expr # let
| 'let' var '=' expr ';' expr # let
// sugar for let %_ = expr; expr
| expr ';' expr # let
| ident '=' expr ';' expr # graph
// mutable update
// | ident '=' expr # writeRef
// | expr '^' # readRef
| expr ';;' expr # let
| GRAPH_VAR '=' expr ';' expr # graph
| ident # identExpr
| scalar # scalarExpr
// | expr '.' NAT # project
// | 'debug' # debug
| meta # metaExpr
| QUOTED_STRING # stringExpr
;
func: 'fn' typeParamSeq? '(' argList ')' ('->' type_)? body ;
defn: 'def' ident typeParamSeq? '(' argList ')' ('->' type_)? body ;
func: 'fn' typeParamList? '(' argList ')' ('->' type_)? body ;
defn: 'def' ident typeParamList? '(' argList ')' ('->' type_)? body ;
argList
: varList
| attrList
| varList ',' attrList
: varList # argNoAttr
| (var ',')* attrSeq # argWithAttr
;
varList: (var (',' var)*)? ;
var: ident (':' type_)? ;
varList: (var (',' var)*)?;
var: LOCAL_VAR (':' type_)?;
attrList: (attr (',' attr)*)? ;
attrSeq: attr (',' attr)*;
attr: CNAME '=' expr ;
// TODO(@jmp): for improved type annotations
// returnAnno: (ident ':')? type_ ;
// relations: 'where' relation (',' relation)* ;
// relation: ident '(' (type_ (',' type_)*)? ')' ;
typeParamSeq
typeParamList
: '[' ']'
| '[' ident (',' ident)* ']'
;
type_
: '(' ')' # tupleType
| '(' type_ ',' ')' # tupleType
| '(' type_ (',' type_)+ ')' # tupleType
| typeIdent # typeIdentType
| 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType
// currently unused
// | typeIdent '[' (type_ (',' type_)*)? ']' # callType
| 'fn' typeParamSeq? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '_' # incompleteType
| NAT # intType
: '(' ')' # tupleType
| '(' type_ ',' ')' # tupleType
| '(' type_ (',' type_)+ ')' # tupleType
| typeIdent # typeIdentType
| 'Tensor' '[' shapeList ',' type_ ']' # tensorType
| 'fn' typeParamList? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '_' # incompleteType
| NAT # intType
;
shapeSeq
: '(' ')'
| '(' shape ',' ')'
| '(' shape (',' shape)+ ')'
shapeList
: '(' shape (',' shape)+ ')'
| '(' ')'
| shape
;
meta : 'meta' '[' CNAME ']' '[' NAT ']';
shape
: '(' shape ')' # parensShape
// | type_ op=('*'|'/') type_ # binOpType
// | type_ op=('+'|'-') type_ # binOpType
| NAT # intShape
: meta # metaShape
| '(' shape ')' # parensShape
| NAT # intShape
;
typeIdent : CNAME ;
typeIdent : CNAME;
// int8, int16, int32, int64
// uint8, uint16, uint32, uint64
// float16, float32, float64
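To illustrate the grammar above (and the roundtripping this commit is about), a small hand-written program in the v0.0.3 text format can be pushed through the parser; the program below is an illustrative sketch, not part of the commit:

    from tvm import relay

    mod = relay.fromtext(
        "v0.0.3\n"
        "/* a block comment\n"
        "   /* block comments now nest */\n"
        "*/\n"
        "def @main(%x: Tensor[(1, 4), float32]) -> Tensor[(1, 4), float32] {\n"
        "  add(%x, %x)\n"
        "}\n")
    print(mod.astext())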
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -19,11 +19,51 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#exprList.
def visitExprList(self, ctx:RelayParser.ExprListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#callNoAttr.
def visitCallNoAttr(self, ctx:RelayParser.CallNoAttrContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#callWithAttr.
def visitCallWithAttr(self, ctx:RelayParser.CallWithAttrContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#funcExpr.
def visitFuncExpr(self, ctx:RelayParser.FuncExprContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#metaExpr.
def visitMetaExpr(self, ctx:RelayParser.MetaExprContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tensor.
def visitTensor(self, ctx:RelayParser.TensorContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#graph.
def visitGraph(self, ctx:RelayParser.GraphContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#identExpr.
def visitIdentExpr(self, ctx:RelayParser.IdentExprContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#stringExpr.
def visitStringExpr(self, ctx:RelayParser.StringExprContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#call.
def visitCall(self, ctx:RelayParser.CallContext):
return self.visitChildren(ctx)
......@@ -39,13 +79,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#parens.
def visitParens(self, ctx:RelayParser.ParensContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#funcExpr.
def visitFuncExpr(self, ctx:RelayParser.FuncExprContext):
# Visit a parse tree produced by RelayParser#paren.
def visitParen(self, ctx:RelayParser.ParenContext):
return self.visitChildren(ctx)
......@@ -59,8 +94,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tensor.
def visitTensor(self, ctx:RelayParser.TensorContext):
# Visit a parse tree produced by RelayParser#projection.
def visitProjection(self, ctx:RelayParser.ProjectionContext):
return self.visitChildren(ctx)
......@@ -69,11 +104,6 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#graph.
def visitGraph(self, ctx:RelayParser.GraphContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#binOp.
def visitBinOp(self, ctx:RelayParser.BinOpContext):
return self.visitChildren(ctx)
......@@ -89,8 +119,13 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#argList.
def visitArgList(self, ctx:RelayParser.ArgListContext):
# Visit a parse tree produced by RelayParser#argNoAttr.
def visitArgNoAttr(self, ctx:RelayParser.ArgNoAttrContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#argWithAttr.
def visitArgWithAttr(self, ctx:RelayParser.ArgWithAttrContext):
return self.visitChildren(ctx)
......@@ -104,8 +139,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#attrList.
def visitAttrList(self, ctx:RelayParser.AttrListContext):
# Visit a parse tree produced by RelayParser#attrSeq.
def visitAttrSeq(self, ctx:RelayParser.AttrSeqContext):
return self.visitChildren(ctx)
......@@ -114,8 +149,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#typeParamSeq.
def visitTypeParamSeq(self, ctx:RelayParser.TypeParamSeqContext):
# Visit a parse tree produced by RelayParser#typeParamList.
def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext):
return self.visitChildren(ctx)
......@@ -149,8 +184,18 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#shapeSeq.
def visitShapeSeq(self, ctx:RelayParser.ShapeSeqContext):
# Visit a parse tree produced by RelayParser#shapeList.
def visitShapeList(self, ctx:RelayParser.ShapeListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#meta.
def visitMeta(self, ctx:RelayParser.MetaContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#metaShape.
def visitMetaShape(self, ctx:RelayParser.MetaShapeContext):
return self.visitChildren(ctx)
......
......@@ -66,34 +66,34 @@ def conv2d(data,
weight : tvm.relay.Expr
The weight expression.
strides : tuple of int, optional
strides : Optional[Tuple[int]]
The strides of convolution.
padding : tuple of int, optional
padding : Optional[Tuple[int]]
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
dilation : Optional[Tuple[int]]
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
groups : Optional[int]
Number of groups for grouped convolution.
channels : int, optional
channels : Optional[int]
Number of output channels of this convolution.
kernel_size : tuple of int, optional
kernel_size : Optional[Tuple[int]]
The spatial dimensions of the convolution kernel.
data_layout : str, optional
data_layout : Optional[str]
Layout of the input.
kernel_layout : str, optional
kernel_layout : Optional[str]
Layout of the weight.
out_layout : str, optional
out_layout : Optional[str]
Layout of the output. By default, out_layout is the same as data_layout.
out_dtype : str, optional
out_dtype : Optional[str]
Specifies the output data type for mixed precision conv2d.
Returns
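As a reference for the optional arguments documented above, a conv2d call might look like the following sketch (shapes are illustrative):

    from tvm import relay

    data = relay.var("data", shape=(1, 3, 224, 224))
    weight = relay.var("weight", shape=(16, 3, 3, 3))
    out = relay.nn.conv2d(data, weight,
                          strides=(1, 1), padding=(1, 1),
                          channels=16, kernel_size=(3, 3))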
......@@ -691,8 +691,30 @@ def dropout(data, rate=0.5):
result : tvm.relay.Expr
The result of dropout
"""
result = _make.dropout(data, rate)
return TupleWrapper(result, 2)[0]
return TupleWrapper(dropout_raw(data, rate), 2)[0]
def dropout_raw(data, rate=0.5):
"""Applies the dropout operation to the input array.
During training, each element of the input is set to zero with
probability ``rate``. The whole array is rescaled by ``1/(1-rate)``
to keep the expected sum of the input unchanged.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
rate : float, optional (default=0.5)
The probability for an element to be reset to 0.
Returns
-------
result : tvm.relay.Expr
The result of dropout
"""
return _make.dropout(data, rate)
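A short sketch contrasting the two entry points (assuming both remain exposed under relay.nn after this change):

    from tvm import relay

    x = relay.var("x", shape=(2, 3))
    y = relay.nn.dropout(x, rate=0.5)        # only the dropped-out result
    raw = relay.nn.dropout_raw(x, rate=0.5)  # the raw call, which produces a tuple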
def batch_norm(data,
......
......@@ -23,4 +23,7 @@ from .. import register_func
def fromtext(data, source_name=None):
"""Parse a Relay program."""
from tvm.relay import _parser
return _parser.fromtext(data, source_name)
x = _parser.fromtext(data + "\n", source_name)
if x is None:
raise Exception("cannot parse: ", data)
return x
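A quick sketch of the wrapper in use; the text must begin with the version header that the grammar's SEMVER token expects:

    from tvm import relay

    f = relay.fromtext("v0.0.3\nfn (%x) { %x }")
    print(f.astext())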
......@@ -42,7 +42,7 @@ def _make_dense_block(data, num_layers, bn_size, growth_rate, index):
layer_out = data
for i in range(num_layers):
layer_out = _make_dense_layer(layer_out, growth_rate, bn_size,
"(%s, %s)" % (index, i))
"%s_%s" % (index, i))
return layer_out
def _make_transition(data, num_output_features, index):
......
......@@ -29,7 +29,7 @@ class Type(RelayNode):
"""Compare two Relay types for structural equivalence using
alpha equivalence.
"""
return bool(_make._type_alpha_equal(self, other))
return bool(_make._alpha_equal(self, other))
def __ne__(self, other):
return not self.__eq__(other)
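A sketch of the structural __eq__ in action:

    from tvm import relay

    t1 = relay.TensorType((1, 4), "float32")
    t2 = relay.TensorType((1, 4), "float32")
    assert t1 == t2                                # alpha-equal, though distinct objects
    assert t1 != relay.TensorType((1, 4), "int32")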
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -18,7 +18,7 @@
*/
/*!
* Copyright (c) 2018 by Contributors
* Copyright (c) 2019 by Contributors
* \file src/tvm/relay/ir/alpha_equal.cc
* \brief Alpha equality check by deep comparing two nodes.
*/
......@@ -27,9 +27,10 @@
#include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
#include "type_functor.h"
#include "../../lang/attr_functor.h"
namespace tvm {
namespace relay {
......@@ -40,8 +41,8 @@ class AlphaEqualHandler:
public ExprFunctor<bool(const Expr&, const Expr&)>,
public PatternFunctor<bool(const Pattern&, const Pattern&)> {
public:
explicit AlphaEqualHandler(bool map_free_var)
: map_free_var_(map_free_var) { }
explicit AlphaEqualHandler(bool map_free_var, bool assert_mode)
: map_free_var_(map_free_var), assert_mode_(assert_mode) { }
/*!
* Check equality of two nodes.
......@@ -76,6 +77,9 @@ class AlphaEqualHandler:
return AttrEqual(lhs, rhs);
}
bool DoubleEqual(double l, double r) {
// Note: double-valued attributes are currently treated as always equal here;
// no exact floating point comparison is performed.
return true;
}
/*!
* Check equality of two attributes.
* \param lhs The left hand operand.
......@@ -83,18 +87,28 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) {
if (&lhs == &rhs) return true;
auto lhsd = lhs.as<DictAttrsNode>();
if (lhsd) {
auto rhsd = lhs.as<DictAttrsNode>();
if (!rhsd) return false;
if (lhsd->dict.size() != rhsd->dict.size()) return false;
for (const auto& k : lhsd->dict) {
if (!Equal(k.second, rhsd->dict[k.first])) return false;
auto compute = [&]() {
if (&lhs == &rhs) return true;
if (auto lhsd = lhs.as<DictAttrsNode>()) {
auto rhsd = rhs.as<DictAttrsNode>();
if (!rhsd) return false;
if (lhsd->dict.size() != rhsd->dict.size()) return false;
for (const auto& k : lhsd->dict) {
if (!Equal(k.second, rhsd->dict[k.first])) return false;
}
return true;
}
return true;
}
return AttrsEqualHandler::Equal(lhs, rhs);
if (auto lhsbn = lhs.as<BatchNormAttrs>()) {
auto rhsbn = rhs.as<BatchNormAttrs>();
if (!rhsbn) return false;
return (lhsbn->axis == rhsbn->axis)
&& DoubleEqual(lhsbn->epsilon, rhsbn->epsilon)
&& (lhsbn->center == rhsbn->center)
&& (lhsbn->scale == rhsbn->scale);
}
return AttrsEqualHandler::Equal(lhs, rhs);
};
return Compare(compute(), lhs, rhs);
}
/*!
* Check equality of two types.
......@@ -107,6 +121,13 @@ class AlphaEqualHandler:
if (!lhs.defined() || !rhs.defined()) return false;
return this->VisitType(lhs, rhs);
}
bool Compare(bool result, const NodeRef& lhs, const NodeRef& rhs) {
if (assert_mode_) {
CHECK(result) << "\n" << AsText(lhs, true) << "\nis not equal to:\n" << AsText(rhs, true);
}
return result;
}
/*!
* Check equality of two expressions.
*
......@@ -120,18 +141,21 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool ExprEqual(const Expr& lhs, const Expr& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
auto it = equal_map_.find(lhs);
if (it != equal_map_.end()) {
return it->second.same_as(rhs);
}
if (this->VisitExpr(lhs, rhs)) {
equal_map_[lhs] = rhs;
return true;
} else {
return false;
}
auto compute = [&]() {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
auto it = equal_map_.find(lhs);
if (it != equal_map_.end()) {
return it->second.same_as(rhs);
}
if (this->VisitExpr(lhs, rhs)) {
equal_map_[lhs] = rhs;
return true;
} else {
return false;
}
};
return Compare(compute(), lhs, rhs);
}
protected:
......@@ -516,32 +540,41 @@ class AlphaEqualHandler:
private:
// whether to map open terms.
bool map_free_var_;
// if in assert mode, the comparison must return true; an error is thrown otherwise.
bool assert_mode_;
// mapping from NodeRef to NodeRef, recording that two nodes are equal to each other
std::unordered_map<NodeRef, NodeRef, NodeHash, NodeEqual> equal_map_;
};
bool AlphaEqual(const Type& lhs, const Type& rhs) {
return AlphaEqualHandler(false).TypeEqual(lhs, rhs);
return AlphaEqualHandler(false, false).TypeEqual(lhs, rhs);
}
bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
return AlphaEqualHandler(false).ExprEqual(lhs, rhs);
return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs);
}
// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API("relay._make._alpha_equal")
.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
return AlphaEqualHandler(false).Equal(a, b);
return AlphaEqualHandler(false, false).Equal(a, b);
});
TVM_REGISTER_API("relay._make._type_alpha_equal")
.set_body_typed<bool(Type, Type)>([](Type a, Type b) {
return AlphaEqualHandler(false).TypeEqual(a, b);
TVM_REGISTER_API("relay._make._assert_alpha_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not alpha equal";
});
TVM_REGISTER_API("relay._make._graph_equal")
.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
return AlphaEqualHandler(true).Equal(a, b);
return AlphaEqualHandler(true, false).Equal(a, b);
});
TVM_REGISTER_API("relay._make._assert_graph_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " are not graph equal";
});
} // namespace relay
......
......@@ -89,7 +89,7 @@ std::string Doc::str() {
return os.str();
}
Doc PrintVec(const std::vector<Doc>& vec, const Doc& sep) {
Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
Doc seq;
if (vec.size() != 0) {
seq = vec[0];
......
......@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
......@@ -46,7 +46,11 @@ using DocAtom = std::shared_ptr<DocAtomNode>;
struct TextNode : DocAtomNode {
std::string str;
explicit TextNode(const std::string& str) : str(str) {}
explicit TextNode(const std::string& str) : str(str) {
if (str.find_first_of("\t\n") != str.npos) {
LOG(WARNING) << "text node: '" << str << "' should not has tab or newline.";
}
}
};
struct LineNode : DocAtomNode {
......@@ -91,8 +95,8 @@ class Doc {
// DSL functions
// Render vectors of docs with a separator. e.g. PrintVec([1, 2, 3], f) -> 1f2f3
Doc PrintVec(const std::vector<Doc>& vec, const Doc& sep = Doc(", "));
// Render vectors of docs with a separator. e.g. PrintSep([1, 2, 3], f) -> 1f2f3
Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep = Doc(", "));
// Print a constant bool value.
Doc PrintBool(bool value);
// Print a data type.
......@@ -116,7 +120,8 @@ Doc PrintConstScalar(DataType dtype, const T* data) {
} else if (dtype == Bool()) {
return PrintBool(data[0] != 0);
} else {
os << dtype << "(" << data[0] << ")";
// todo(@M.K.) this is unsafe. fix.
os << data[0];
}
return Doc(os.str());
}
......
......@@ -32,6 +32,7 @@
* - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/
#include <dmlc/json.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include <tvm/relay/pattern_functor.h>
......@@ -43,6 +44,17 @@
namespace tvm {
namespace relay {
Doc Brace(const Doc& d,
const std::string& open = "{",
const std::string& close = "}",
int indent = 2) {
Doc doc;
doc << open;
doc << Indent(indent, PrintNewLine() << d) << PrintNewLine();
doc << close;
return doc;
}
/*!
* \brief Meta data context for PrettyPrinter.
*
......@@ -108,8 +120,10 @@ class TextMetaDataContext {
if (it != meta_repr_.end()) {
return it->second;
}
std::string type_key = node->type_key();
CHECK(!type_key.empty());
Array<NodeRef>& mvector =
meta_data_[node->type_key()];
meta_data_[type_key];
int64_t index = static_cast<int64_t>(mvector.size());
mvector.push_back(node);
Doc doc;
......@@ -117,14 +131,18 @@ class TextMetaDataContext {
meta_repr_[node] = doc;
return meta_repr_[node];
}
Doc PrintKeyValue(const std::string& str, const Doc& v) const {
return Doc("\"") << str << "\": " << v;
}
/*!
* \brief Get the metadata section in json format.
* \return the meta data string.
*/
std::string GetMetaSection() const {
if (meta_data_.size() == 0) return std::string();
return SaveJSON(Map<std::string, NodeRef>(
meta_data_.begin(), meta_data_.end()));
Doc GetMetaSection() const {
if (meta_data_.size() == 0) return Doc();
return Doc(SaveJSON(Map<std::string, NodeRef>(meta_data_.begin(), meta_data_.end())));
}
/*! \return whether the meta data context is empty. */
......@@ -172,12 +190,11 @@ class PrettyPrinter :
}
// indent a new body
// TODO(jmp): indent should be an instance variable of the printer
Doc PrintBody(const NodeRef& node, int indent = 2) {
Doc doc;
Doc body;
doc << "{";
doc << Indent(indent, body << "\n" << PrintScope(node)) << "\n";
doc << Indent(indent, body << PrintNewLine() << PrintScope(node)) << PrintNewLine();
doc << "}";
return doc;
}
......@@ -203,13 +220,12 @@ class PrettyPrinter :
Doc doc;
doc << PrintScope(node);
if (!meta_.empty()) {
doc << PrintNewLine();
if (show_meta_data_) {
std::string meta_json = meta_.GetMetaSection();
// append meta data in the end.
doc << "\n" << "/* meta data */" << "\n" << meta_json;
doc << "METADATA:" << PrintNewLine() << meta_.GetMetaSection();
} else {
doc << "\n"
<< "// meta data omitted. you can use show_meta_data=True to include meta data";
doc << "// meta data omitted. you can use show_meta_data=True to include meta data";
}
}
return doc;
......@@ -361,7 +377,7 @@ class PrettyPrinter :
// wrap GNFed let in brackets
Doc body;
printed_expr << "{";
printed_expr << Indent(2, body << "\n" << VisitExpr(expr)) << "\n";
printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine();
printed_expr << "}";
} else {
printed_expr = VisitExpr(expr);
......@@ -373,7 +389,7 @@ class PrettyPrinter :
if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
doc_stack_.back() << "free_var " << printed_expr << "\n";
doc_stack_.back() << "free_var " << printed_expr << PrintNewLine();
// Memoization is done in AllocVar.
return memo_[expr];
} else if (inline_expr) {
......@@ -422,7 +438,7 @@ class PrettyPrinter :
fields.push_back(Print(field));
}
Doc doc;
doc << "(" << PrintVec(fields);
doc << "(" << PrintSep(fields);
// conform to python tuple format (1,)
if (op->fields.size() == 1) {
doc << ",";
......@@ -460,31 +476,31 @@ class PrettyPrinter :
}
Doc PrintFunc(const Doc& prefix, const Function& fn) {
Doc doc;
doc << prefix;
if (fn->type_params.size() > 0) {
doc << "<";
std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) {
type_params.push_back(AllocTypeVar(tv));
}
doc << PrintVec(type_params);
doc << ">";
}
doc << "(";
std::vector<Doc> params;
for (Var param : fn->params) {
params.push_back(AllocVar(param));
}
for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
params.push_back(d);
}
doc << PrintVec(params) << ") ";
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
Doc doc;
doc << prefix;
if (fn->type_params.size() > 0) {
doc << "<";
std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) {
type_params.push_back(AllocTypeVar(tv));
}
doc << PrintBody(fn->body);
return doc;
doc << PrintSep(type_params);
doc << ">";
}
doc << "(";
std::vector<Doc> params;
for (Var param : fn->params) {
params.push_back(AllocVar(param));
}
for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
params.push_back(d);
}
doc << PrintSep(params) << ") ";
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
}
doc << PrintBody(fn->body);
return doc;
}
Doc PrintMod(const Module& mod) {
......@@ -493,13 +509,13 @@ class PrettyPrinter :
for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second);
std::ostringstream os;
if (counter++ != 0) {
doc << "\n";
doc << PrintNewLine();
}
std::ostringstream os;
os << "def @" << kv.first->name_hint;
doc << PrintFunc(Doc(os.str()), kv.second);
doc << "\n";
doc << PrintNewLine();
}
return doc;
}
......@@ -528,7 +544,7 @@ class PrettyPrinter :
args.push_back(d);
}
doc << Print(op->op);
return doc << "(" << PrintVec(args) << ")";
return doc << "(" << PrintSep(args) << ")";
}
Doc VisitExpr_(const RefCreateNode* op) final {
......@@ -558,7 +574,7 @@ class PrettyPrinter :
clauses.push_back(clause_doc << Print(clause->lhs) << " -> "
<< Print(clause->rhs));
}
doc << Indent(2, body << "\n" << PrintVec(clauses, Doc("\n"))) << "\n";
doc << Indent(2, body << PrintNewLine() << PrintSep(clauses, PrintNewLine())) << PrintNewLine();
doc << "}";
return doc;
}
......@@ -570,7 +586,7 @@ class PrettyPrinter :
for (const auto& pat : p->patterns) {
pats.push_back(Print(pat));
}
return doc << PrintVec(pats) << ")";
return doc << PrintSep(pats) << ")";
}
Doc VisitPattern_(const PatternVarNode* pv) final {
......@@ -617,7 +633,7 @@ class PrettyPrinter :
args.push_back(PrintType(t, false));
}
doc << "[";
doc << PrintVec(args);
doc << PrintSep(args);
doc << "]";
return doc;
}
......@@ -633,11 +649,7 @@ class PrettyPrinter :
for (NodeRef shape : node->shape) {
shapes.push_back(PrintAttr(shape));
}
doc << PrintVec(shapes);
// conform to python tuple format (1,)
if (node->shape.size() == 1) {
doc << ",";
}
doc << PrintSep(shapes);
return doc << "), " << PrintDType(node->dtype) << "]";
}
......@@ -647,7 +659,7 @@ class PrettyPrinter :
fields.push_back(Print(field));
}
Doc doc;
doc << "(" << PrintVec(fields);
doc << "(" << PrintSep(fields);
// conform to python tuple format (1,)
if (node->fields.size() == 1) {
doc << ",";
......@@ -664,14 +676,14 @@ class PrettyPrinter :
for (Type type_param : node->type_params) {
type_params.push_back(Print(type_param));
}
doc << PrintVec(type_params);
doc << PrintSep(type_params);
doc << ">";
}
std::vector<Doc> arg_types;
for (Type arg_type : node->arg_types) {
arg_types.push_back(Print(arg_type));
}
return doc << "(" << PrintVec(arg_types) << ") -> " << Print(node->ret_type);
return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type);
}
Doc VisitType_(const RefTypeNode* node) final {
......@@ -710,7 +722,7 @@ class PrettyPrinter :
for (NodePtr<Node> val : op->data) {
arr_vals.push_back(PrintAttr(NodeRef(val)));
}
doc << PrintVec(arr_vals);
doc << PrintSep(arr_vals);
doc << "]";
return doc;
}
......@@ -771,7 +783,9 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
}
void Visit(const char* key, double* value) final {
PrintKV(key, *value);
Doc doc;
doc << key << "=" << *value << "f";
docs->push_back(doc);
}
void Visit(const char* key, int64_t* value) final {
PrintKV(key, *value);
......@@ -843,7 +857,7 @@ std::string PrettyPrint_(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
Doc doc;
doc << "v0.0.3" << "\n"
doc << "v0.0.3" << PrintNewLine()
<< PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str();
}
......
......@@ -16,7 +16,7 @@
# under the License.
import tvm
from tvm import relay
from tvm.relay.analysis import alpha_equal
from tvm.relay.analysis import alpha_equal, assert_alpha_equal
from nose.tools import nottest, raises
from numpy import isclose
from typing import Union
......@@ -60,12 +60,9 @@ TYPES = {
"float16x4",
}
def assert_alpha_equal(a, b):
if not alpha_equal(a, b):
raise Exception("lhs is: ", str(a), "rhs is: ", str(b))
def roundtrip(expr):
assert_alpha_equal(relay.fromtext(str(expr)), expr)
x = relay.fromtext(str(expr))
assert_alpha_equal(x, expr)
def parse_text(code):
......@@ -112,6 +109,16 @@ def test_comments():
UNIT
)
assert parses_as(
"""
/* This is a block comment!
/*Block comment is recursive!*/
*/
()
""",
UNIT
)
def test_int_literal():
assert isinstance(parse_text("1"), relay.Constant)
......@@ -224,7 +231,7 @@ def test_let():
def test_seq():
assert parses_as(
"(); ()",
"();; ()",
relay.Let(
_,
UNIT,
......@@ -538,7 +545,7 @@ def test_tensor_type():
)
assert parses_as(
"let %_ : Tensor[(1,), float32] = (); ()",
"let %_ : Tensor[(1), float32] = (); ()",
relay.Let(
relay.Var("_", relay.TensorType((1,), "float32")),
UNIT,
......
......@@ -15,14 +15,27 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import relay
import tvm.relay.testing
import numpy as np
from tvm import relay
from tvm.relay import Expr
from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equal, free_vars
do_print = [False]
SEMVER = "v0.0.3\n"
def astext(p, graph_equal=False):
txt = p.astext()
if isinstance(p, Expr) and free_vars(p):
return txt
x = relay.fromtext(txt)
if graph_equal:
assert_graph_equal(x, p)
else:
assert_alpha_equal(x, p)
return txt
def show(text):
if do_print[0]:
print("---------------------------")
......@@ -35,8 +48,8 @@ def test_func():
z = relay.add(x, one)
z = relay.add(z, z)
f = relay.Function([x, y], z)
show(z.astext())
show(f.astext())
show(astext(z))
show(astext(f))
def test_env():
......@@ -47,7 +60,7 @@ def test_env():
f = relay.Function([x, y], z)
env = relay.Module()
env["myf"] = f
text = env.astext()
text = astext(env)
assert "def @myf" in text
assert "def @myf" in str(env)
assert "add(%0, %0) /* ty=float32 */" in text
......@@ -65,7 +78,7 @@ def test_meta_data():
padding=(1, 1),
channels=2)
f = relay.Function([x, w], z)
text = f.astext()
text = astext(f, graph_equal=True)
text_no_meta = str(f)
assert "channels=2" in text
assert "channels=2" in text_no_meta
......@@ -73,25 +86,22 @@ def test_meta_data():
assert "meta[Variable][0]" in text_no_meta
assert "type_key" in text
assert "type_key" not in text_no_meta
show(text)
show(f)
text = relay.const([1,2,3]).astext()
text = astext(relay.const([1,2,3]))
assert "meta[relay.Constant][0]" in text
show(text)
def test_call_attrs():
x = relay.var("x")
# non default args
z = relay.nn.softmax(x, axis=2)
assert "axis=2" in z.astext()
assert "axis=2" in astext(z)
# default args
z = relay.nn.softmax(x)
assert "softmax(%x)" in z.astext()
assert "softmax(%x)" in astext(z)
# non default args
z = relay.expand_dims(x, axis=2, num_newaxis=2)
assert "num_newaxis=2" in z.astext()
assert "num_newaxis=2" in astext(z)
def test_let_if_scope():
......@@ -111,68 +121,72 @@ def test_let_if_scope():
result = sb.get()
f = relay.Function([x, y, cond], result)
text = f.astext()
text = astext(f)
assert text.count("{") == 4
assert "%cond: bool" in text
show(f.astext())
show(astext(f))
def test_variable_name():
# avoid a pure-number name even if the namehint is a pure number
v1 = relay.var("1")
assert "%v1" in v1.astext()
assert "%v1" in astext(v1)
def test_mlp():
net, params = tvm.relay.testing.mlp.get_workload(batch_size=1)
net.astext()
astext(net)
def test_resnet():
net, params = tvm.relay.testing.resnet.get_workload(batch_size=1)
net.astext()
astext(net)
def test_mobilenet():
net, params = tvm.relay.testing.mobilenet.get_workload(batch_size=1)
net.astext()
astext(net)
def test_dqn():
net, params = tvm.relay.testing.dqn.get_workload(batch_size=1)
net.astext()
astext(net)
def test_dcgan():
net, params = tvm.relay.testing.dcgan.get_workload(batch_size=1)
net.astext()
astext(net)
def test_lstm():
net, params = tvm.relay.testing.lstm.get_workload(1, 1)
astext(net)
net, params = tvm.relay.testing.lstm.get_workload(4, 4)
net.astext()
astext(net)
def test_inception_v3():
net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
net.astext()
astext(net)
def test_squeezenet():
for version in ['1.0', '1.1']:
net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version)
net.astext()
astext(net)
def test_vgg():
net, params = tvm.relay.testing.vgg.get_workload(batch_size=1)
net.astext()
astext(net)
def test_densenet():
net, params = tvm.relay.testing.densenet.get_workload(batch_size=1)
net.astext()
astext(net)
def test_call_node_order():
x = relay.var("x")
y = relay.var("y")
assert relay.Call(relay.Function([x], x), [relay.Call(relay.Function([y], y), [relay.const(1)])]).astext() == SEMVER + \
prog = relay.Call(relay.Function([x], x), [relay.Call(relay.Function([y], y), [relay.const(1)])])
assert astext(prog) == SEMVER + \
("%0 = fn (%y) {\n"
" %y\n"
"};\n"
......@@ -185,17 +199,25 @@ def test_call_node_order():
def test_let_inlining():
tup = relay.Tuple([relay.const(0), relay.const(0)])
x = relay.var("x")
assert relay.Let(x, tup, tup).astext() == SEMVER + \
assert astext(relay.Let(x, tup, tup)) == SEMVER + \
("%0 = (0, 0);\n"
"let %x = %0;\n"
"%0")
assert relay.Let(x, tup, x).astext() == SEMVER + \
assert astext(relay.Let(x, tup, x)) == SEMVER + \
("let %x = (0, 0);\n"
"%x")
def test_zeros():
x = relay.op.zeros([], "float32")
astext(x)
if __name__ == "__main__":
do_print[0] = True
test_lstm()
test_zeros()
test_meta_data()
test_let_inlining()
test_resnet()
test_mobilenet()
test_mlp()
......@@ -207,9 +229,7 @@ if __name__ == "__main__":
test_densenet()
test_func()
test_env()
test_meta_data()
test_call_attrs()
test_let_if_scope()
test_variable_name()
test_call_node_order()
test_let_inlining()