Commit 2973f8a6 by 雾雨魔理沙 Committed by Tianqi Chen

[Relay] parser/pretty printer roundtripping (#3536)

parent e5efc632
...@@ -224,6 +224,20 @@ def alpha_equal(lhs, rhs): ...@@ -224,6 +224,20 @@ def alpha_equal(lhs, rhs):
return bool(_make._alpha_equal(lhs, rhs)) return bool(_make._alpha_equal(lhs, rhs))
def assert_alpha_equal(lhs, rhs):
"""Assert that two Relay expr is structurally equivalent. (alpha equivalence).
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
"""
_make._assert_alpha_equal(lhs, rhs)
def graph_equal(lhs, rhs): def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence. """Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that The difference between this and alpha-equality is that
...@@ -246,6 +260,23 @@ def graph_equal(lhs, rhs): ...@@ -246,6 +260,23 @@ def graph_equal(lhs, rhs):
return bool(_make._graph_equal(lhs, rhs)) return bool(_make._graph_equal(lhs, rhs))
def assert_graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.
Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.
rhs : tvm.relay.Expr
One of the input Expression.
"""
_make._assert_graph_equal(lhs, rhs)
def collect_device_info(expr): def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device """Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators. ids are propagated from the `device_copy` operators.
......
...@@ -17,15 +17,20 @@ ...@@ -17,15 +17,20 @@
* under the License. * under the License.
*/ */
// list = *, seq = ?
grammar Relay; grammar Relay;
SEMVER: 'v0.0.3' ; SEMVER: 'v0.0.3' ;
// Lexing // Lexing
// comments // comments
WS : [ \t\n\r]+ -> skip ; COMMENT : '/*' (COMMENT|.)*? '*/' -> skip;
LINE_COMMENT : '//' .*? '\n' -> skip ; WS : [ \t\n\r]+ -> skip;
COMMENT : '/*' .*? '*/' -> skip ; LINE_COMMENT : '//' .*? '\n' -> skip;
fragment ESCAPED_QUOTE : '\\"';
QUOTED_STRING : '"' ( ESCAPED_QUOTE | ~('\n'|'\r') )*? '"';
// operators // operators
MUL: '*' ; MUL: '*' ;
...@@ -39,18 +44,18 @@ GE: '>=' ; ...@@ -39,18 +44,18 @@ GE: '>=' ;
EQ: '==' ; EQ: '==' ;
NE: '!=' ; NE: '!=' ;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
MUT: 'mut' ;
BOOL_LIT BOOL_LIT
: 'True' : 'True'
| 'False' | 'False'
; ;
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)*;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
DATATYPE : 'int64';
// non-negative floats // non-negative floats
fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4 fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4
...@@ -60,109 +65,99 @@ FLOAT : PREFLOAT 'f'; ...@@ -60,109 +65,99 @@ FLOAT : PREFLOAT 'f';
NAT: DIGIT+ ; NAT: DIGIT+ ;
fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...] fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...]
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ; fragment LETTER: [a-zA-Z];
fragment LETTER: [a-zA-Z] ; fragment DIGIT: [0-9];
fragment DIGIT: [0-9] ;
METADATA: 'METADATA:' .*;
// Parsing // Parsing
// A Relay program is a list of global definitions or an expression. // A Relay program is a list of global definitions or an expression.
prog: SEMVER (defn* | expr) EOF ; prog: SEMVER (defn* | expr) METADATA? EOF ;
// option: 'set' ident BOOL_LIT ; // option: 'set' ident BOOL_LIT ;
exprList: (expr (',' expr)*)?;
callList
: exprList # callNoAttr
| (expr ',')* attrSeq # callWithAttr
;
expr expr
// operators // operators
: '(' expr ')' # parens : '(' expr ')' # paren
| '{' expr '}' # paren
// function application // function application
| expr '(' (expr (',' expr)*)? ')' # call | expr '(' callList ')' # call
| '-' expr # neg | '-' expr # neg
| expr op=('*'|'/') expr # binOp | expr op=('*'|'/') expr # binOp
| expr op=('+'|'-') expr # binOp | expr op=('+'|'-') expr # binOp
| expr op=('<'|'>'|'<='|'>=') expr # binOp | expr op=('<'|'>'|'<='|'>=') expr # binOp
| expr op=('=='|'!=') expr # binOp | expr op=('=='|'!=') expr # binOp
// function definition // function definition
| func # funcExpr | func # funcExpr
// tuples and tensors // tuples and tensors
| '(' ')' # tuple | '(' ')' # tuple
| '(' expr ',' ')' # tuple | '(' expr ',' ')' # tuple
| '(' expr (',' expr)+ ')' # tuple | '(' expr (',' expr)+ ')' # tuple
| expr '.' NAT # projection
| '[' (expr (',' expr)*)? ']' # tensor | '[' (expr (',' expr)*)? ']' # tensor
| 'if' '(' expr ')' body 'else' body # ifElse | 'if' '(' expr ')' body 'else' body # ifElse
// sequencing // sequencing
| 'let' MUT? var '=' expr ';' expr # let | 'let' var '=' expr ';' expr # let
| 'let' MUT? var '=' '{' expr '}' ';' expr # let
// sugar for let %_ = expr; expr // sugar for let %_ = expr; expr
| expr ';' expr # let | expr ';;' expr # let
| ident '=' expr ';' expr # graph | GRAPH_VAR '=' expr ';' expr # graph
// mutable update
// | ident '=' expr # writeRef
// | expr '^' # readRef
| ident # identExpr | ident # identExpr
| scalar # scalarExpr | scalar # scalarExpr
// | expr '.' NAT # project | meta # metaExpr
// | 'debug' # debug | QUOTED_STRING # stringExpr
; ;
func: 'fn' typeParamSeq? '(' argList ')' ('->' type_)? body ; func: 'fn' typeParamList? '(' argList ')' ('->' type_)? body ;
defn: 'def' ident typeParamSeq? '(' argList ')' ('->' type_)? body ; defn: 'def' ident typeParamList? '(' argList ')' ('->' type_)? body ;
argList argList
: varList : varList # argNoAttr
| attrList | (var ',')* attrSeq # argWithAttr
| varList ',' attrList
; ;
varList: (var (',' var)*)? ; varList: (var (',' var)*)?;
var: ident (':' type_)? ; var: LOCAL_VAR (':' type_)?;
attrList: (attr (',' attr)*)? ; attrSeq: attr (',' attr)*;
attr: CNAME '=' expr ; attr: CNAME '=' expr ;
// TODO(@jmp): for improved type annotations typeParamList
// returnAnno: (ident ':')? type_ ;
// relations: 'where' relation (',' relation)* ;
// relation: ident '(' (type_ (',' type_)*)? ')' ;
typeParamSeq
: '[' ']' : '[' ']'
| '[' ident (',' ident)* ']' | '[' ident (',' ident)* ']'
; ;
type_ type_
: '(' ')' # tupleType : '(' ')' # tupleType
| '(' type_ ',' ')' # tupleType | '(' type_ ',' ')' # tupleType
| '(' type_ (',' type_)+ ')' # tupleType | '(' type_ (',' type_)+ ')' # tupleType
| typeIdent # typeIdentType | typeIdent # typeIdentType
| 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType | 'Tensor' '[' shapeList ',' type_ ']' # tensorType
// currently unused | 'fn' typeParamList? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
// | typeIdent '[' (type_ (',' type_)*)? ']' # callType | '_' # incompleteType
| 'fn' typeParamSeq? '(' (type_ (',' type_)*)? ')' '->' type_ # funcType | NAT # intType
| '_' # incompleteType
| NAT # intType
; ;
shapeSeq shapeList
: '(' ')' : '(' shape (',' shape)+ ')'
| '(' shape ',' ')' | '(' ')'
| '(' shape (',' shape)+ ')' | shape
; ;
meta : 'meta' '[' CNAME ']' '[' NAT ']';
shape shape
: '(' shape ')' # parensShape : meta # metaShape
// | type_ op=('*'|'/') type_ # binOpType | '(' shape ')' # parensShape
// | type_ op=('+'|'-') type_ # binOpType | NAT # intShape
| NAT # intShape
; ;
typeIdent : CNAME ; typeIdent : CNAME;
// int8, int16, int32, int64 // int8, int16, int32, int64
// uint8, uint16, uint32, uint64 // uint8, uint16, uint32, uint64
// float16, float32, float64 // float16, float32, float64
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -19,11 +19,51 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -19,11 +19,51 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#exprList.
def visitExprList(self, ctx:RelayParser.ExprListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#callNoAttr.
def visitCallNoAttr(self, ctx:RelayParser.CallNoAttrContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#callWithAttr.
def visitCallWithAttr(self, ctx:RelayParser.CallWithAttrContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#funcExpr.
def visitFuncExpr(self, ctx:RelayParser.FuncExprContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#metaExpr.
def visitMetaExpr(self, ctx:RelayParser.MetaExprContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tensor.
def visitTensor(self, ctx:RelayParser.TensorContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#graph.
def visitGraph(self, ctx:RelayParser.GraphContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#identExpr. # Visit a parse tree produced by RelayParser#identExpr.
def visitIdentExpr(self, ctx:RelayParser.IdentExprContext): def visitIdentExpr(self, ctx:RelayParser.IdentExprContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#stringExpr.
def visitStringExpr(self, ctx:RelayParser.StringExprContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#call. # Visit a parse tree produced by RelayParser#call.
def visitCall(self, ctx:RelayParser.CallContext): def visitCall(self, ctx:RelayParser.CallContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -39,13 +79,8 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -39,13 +79,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#parens. # Visit a parse tree produced by RelayParser#paren.
def visitParens(self, ctx:RelayParser.ParensContext): def visitParen(self, ctx:RelayParser.ParenContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#funcExpr.
def visitFuncExpr(self, ctx:RelayParser.FuncExprContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -59,8 +94,8 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -59,8 +94,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#tensor. # Visit a parse tree produced by RelayParser#projection.
def visitTensor(self, ctx:RelayParser.TensorContext): def visitProjection(self, ctx:RelayParser.ProjectionContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -69,11 +104,6 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -69,11 +104,6 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#graph.
def visitGraph(self, ctx:RelayParser.GraphContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#binOp. # Visit a parse tree produced by RelayParser#binOp.
def visitBinOp(self, ctx:RelayParser.BinOpContext): def visitBinOp(self, ctx:RelayParser.BinOpContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -89,8 +119,13 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -89,8 +119,13 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#argList. # Visit a parse tree produced by RelayParser#argNoAttr.
def visitArgList(self, ctx:RelayParser.ArgListContext): def visitArgNoAttr(self, ctx:RelayParser.ArgNoAttrContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#argWithAttr.
def visitArgWithAttr(self, ctx:RelayParser.ArgWithAttrContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -104,8 +139,8 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -104,8 +139,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#attrList. # Visit a parse tree produced by RelayParser#attrSeq.
def visitAttrList(self, ctx:RelayParser.AttrListContext): def visitAttrSeq(self, ctx:RelayParser.AttrSeqContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -114,8 +149,8 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -114,8 +149,8 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#typeParamSeq. # Visit a parse tree produced by RelayParser#typeParamList.
def visitTypeParamSeq(self, ctx:RelayParser.TypeParamSeqContext): def visitTypeParamList(self, ctx:RelayParser.TypeParamListContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
...@@ -149,8 +184,18 @@ class RelayVisitor(ParseTreeVisitor): ...@@ -149,8 +184,18 @@ class RelayVisitor(ParseTreeVisitor):
return self.visitChildren(ctx) return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#shapeSeq. # Visit a parse tree produced by RelayParser#shapeList.
def visitShapeSeq(self, ctx:RelayParser.ShapeSeqContext): def visitShapeList(self, ctx:RelayParser.ShapeListContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#meta.
def visitMeta(self, ctx:RelayParser.MetaContext):
return self.visitChildren(ctx)
# Visit a parse tree produced by RelayParser#metaShape.
def visitMetaShape(self, ctx:RelayParser.MetaShapeContext):
return self.visitChildren(ctx) return self.visitChildren(ctx)
......
...@@ -66,34 +66,34 @@ def conv2d(data, ...@@ -66,34 +66,34 @@ def conv2d(data,
weight : tvm.relay.Expr weight : tvm.relay.Expr
The weight expressions. The weight expressions.
strides : tuple of int, optional strides : Optional[Tuple[int]]
The strides of convolution. The strides of convolution.
padding : tuple of int, optional padding : Optional[Tuple[int]]
The padding of convolution on both sides of inputs before convolution. The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional dilation : Optional[Tuple[int]]
Specifies the dilation rate to be used for dilated convolution. Specifies the dilation rate to be used for dilated convolution.
groups : int, optional groups : Optional[int]
Number of groups for grouped convolution. Number of groups for grouped convolution.
channels : int, optional channels : Optional[int]
Number of output channels of this convolution. Number of output channels of this convolution.
kernel_size : tuple of int, optional kernel_size : Optional[Tuple[int]]
The spatial of the convolution kernel. The spatial of the convolution kernel.
data_layout : str, optional data_layout : Optional[str]
Layout of the input. Layout of the input.
kernel_layout : str, optional kernel_layout : Optional[str]
Layout of the weight. Layout of the weight.
out_layout : str, optional out_layout : Optional[str]
Layout of the output, by default, out_layout is the same as data_layout Layout of the output, by default, out_layout is the same as data_layout
out_dtype : str, optional out_dtype : Optional[str]
Specifies the output data type for mixed precision conv2d. Specifies the output data type for mixed precision conv2d.
Returns Returns
...@@ -691,8 +691,30 @@ def dropout(data, rate=0.5): ...@@ -691,8 +691,30 @@ def dropout(data, rate=0.5):
result : tvm.relay.Expr result : tvm.relay.Expr
The result of dropout The result of dropout
""" """
result = _make.dropout(data, rate) return TupleWrapper(dropout_raw(data, rate), 2)[0]
return TupleWrapper(result, 2)[0]
def dropout_raw(data, rate=0.5):
"""Applies the dropout operation to the input array.
During training, each element of the input is set to zero with
probability ``p``. The whole array is rescaled by ``1/(1-p)``
to keep the expected sum of the input unchanged.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
rate : float, optional (default=0.5)
The probability for an element to be reset to 0.
Returns
-------
result : tvm.relay.Expr
The result of dropout
"""
return _make.dropout(data, rate)
def batch_norm(data, def batch_norm(data,
......
...@@ -23,4 +23,7 @@ from .. import register_func ...@@ -23,4 +23,7 @@ from .. import register_func
def fromtext(data, source_name=None): def fromtext(data, source_name=None):
"""Parse a Relay program.""" """Parse a Relay program."""
from tvm.relay import _parser from tvm.relay import _parser
return _parser.fromtext(data, source_name) x = _parser.fromtext(data + "\n", source_name)
if x is None:
raise Exception("cannot parse: ", data)
return x
...@@ -42,7 +42,7 @@ def _make_dense_block(data, num_layers, bn_size, growth_rate, index): ...@@ -42,7 +42,7 @@ def _make_dense_block(data, num_layers, bn_size, growth_rate, index):
layer_out = data layer_out = data
for i in range(num_layers): for i in range(num_layers):
layer_out = _make_dense_layer(layer_out, growth_rate, bn_size, layer_out = _make_dense_layer(layer_out, growth_rate, bn_size,
"(%s, %s)" % (index, i)) "%s_%s" % (index, i))
return layer_out return layer_out
def _make_transition(data, num_output_features, index): def _make_transition(data, num_output_features, index):
......
...@@ -29,7 +29,7 @@ class Type(RelayNode): ...@@ -29,7 +29,7 @@ class Type(RelayNode):
"""Compare two Relay types for structural equivalence using """Compare two Relay types for structural equivalence using
alpha equivalence. alpha equivalence.
""" """
return bool(_make._type_alpha_equal(self, other)) return bool(_make._alpha_equal(self, other))
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
*/ */
/*! /*!
* Copyright (c) 2018 by Contributors * Copyright (c) 2019 by Contributors
* \file src/tvm/relay/ir/alpha_equal.cc * \file src/tvm/relay/ir/alpha_equal.cc
* \brief Alpha equality check by deep comparing two nodes. * \brief Alpha equality check by deep comparing two nodes.
*/ */
...@@ -27,9 +27,10 @@ ...@@ -27,9 +27,10 @@
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
#include <tvm/runtime/ndarray.h> #include <tvm/runtime/ndarray.h>
#include <tvm/relay/analysis.h> #include <tvm/relay/analysis.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/attrs/nn.h>
#include "type_functor.h" #include "type_functor.h"
#include "../../lang/attr_functor.h" #include "../../lang/attr_functor.h"
namespace tvm { namespace tvm {
namespace relay { namespace relay {
...@@ -40,8 +41,8 @@ class AlphaEqualHandler: ...@@ -40,8 +41,8 @@ class AlphaEqualHandler:
public ExprFunctor<bool(const Expr&, const Expr&)>, public ExprFunctor<bool(const Expr&, const Expr&)>,
public PatternFunctor<bool(const Pattern&, const Pattern&)> { public PatternFunctor<bool(const Pattern&, const Pattern&)> {
public: public:
explicit AlphaEqualHandler(bool map_free_var) explicit AlphaEqualHandler(bool map_free_var, bool assert_mode)
: map_free_var_(map_free_var) { } : map_free_var_(map_free_var), assert_mode_(assert_mode) { }
/*! /*!
* Check equality of two nodes. * Check equality of two nodes.
...@@ -76,6 +77,9 @@ class AlphaEqualHandler: ...@@ -76,6 +77,9 @@ class AlphaEqualHandler:
return AttrEqual(lhs, rhs); return AttrEqual(lhs, rhs);
} }
bool DoubleEqual(double l, double r) {
return true;
}
/*! /*!
* Check equality of two attributes. * Check equality of two attributes.
* \param lhs The left hand operand. * \param lhs The left hand operand.
...@@ -83,18 +87,28 @@ class AlphaEqualHandler: ...@@ -83,18 +87,28 @@ class AlphaEqualHandler:
* \return The comparison result. * \return The comparison result.
*/ */
bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) { bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) {
if (&lhs == &rhs) return true; auto compute = [&]() {
auto lhsd = lhs.as<DictAttrsNode>(); if (&lhs == &rhs) return true;
if (lhsd) { if (auto lhsd = lhs.as<DictAttrsNode>()) {
auto rhsd = lhs.as<DictAttrsNode>(); auto rhsd = lhs.as<DictAttrsNode>();
if (!rhsd) return false; if (!rhsd) return false;
if (lhsd->dict.size() != rhsd->dict.size()) return false; if (lhsd->dict.size() != rhsd->dict.size()) return false;
for (const auto& k : lhsd->dict) { for (const auto& k : lhsd->dict) {
if (!Equal(k.second, rhsd->dict[k.first])) return false; if (!Equal(k.second, rhsd->dict[k.first])) return false;
}
return true;
} }
return true; if (auto lhsbn = lhs.as<BatchNormAttrs>()) {
} auto rhsbn = rhs.as<BatchNormAttrs>();
return AttrsEqualHandler::Equal(lhs, rhs); if (!rhsbn) return false;
return (lhsbn->axis == rhsbn->axis)
&& DoubleEqual(lhsbn->epsilon, rhsbn->epsilon)
&& (lhsbn->center == rhsbn->center)
&& (lhsbn->scale == rhsbn->scale);
}
return AttrsEqualHandler::Equal(lhs, rhs);
};
return Compare(compute(), lhs, rhs);
} }
/*! /*!
* Check equality of two types. * Check equality of two types.
...@@ -107,6 +121,13 @@ class AlphaEqualHandler: ...@@ -107,6 +121,13 @@ class AlphaEqualHandler:
if (!lhs.defined() || !rhs.defined()) return false; if (!lhs.defined() || !rhs.defined()) return false;
return this->VisitType(lhs, rhs); return this->VisitType(lhs, rhs);
} }
bool Compare(bool result, const NodeRef& lhs, const NodeRef& rhs) {
if (assert_mode_) {
CHECK(result) << "\n" << AsText(lhs, true) << "\nis not equal to:\n" << AsText(rhs, true);
}
return result;
}
/*! /*!
* Check equality of two expressions. * Check equality of two expressions.
* *
...@@ -120,18 +141,21 @@ class AlphaEqualHandler: ...@@ -120,18 +141,21 @@ class AlphaEqualHandler:
* \return The comparison result. * \return The comparison result.
*/ */
bool ExprEqual(const Expr& lhs, const Expr& rhs) { bool ExprEqual(const Expr& lhs, const Expr& rhs) {
if (lhs.same_as(rhs)) return true; auto compute = [&]() {
if (!lhs.defined() || !rhs.defined()) return false; if (lhs.same_as(rhs)) return true;
auto it = equal_map_.find(lhs); if (!lhs.defined() || !rhs.defined()) return false;
if (it != equal_map_.end()) { auto it = equal_map_.find(lhs);
return it->second.same_as(rhs); if (it != equal_map_.end()) {
} return it->second.same_as(rhs);
if (this->VisitExpr(lhs, rhs)) { }
equal_map_[lhs] = rhs; if (this->VisitExpr(lhs, rhs)) {
return true; equal_map_[lhs] = rhs;
} else { return true;
return false; } else {
} return false;
}
};
return Compare(compute(), lhs, rhs);
} }
protected: protected:
...@@ -516,32 +540,41 @@ class AlphaEqualHandler: ...@@ -516,32 +540,41 @@ class AlphaEqualHandler:
private: private:
// whether to map open terms. // whether to map open terms.
bool map_free_var_; bool map_free_var_;
// if in assert mode, must return true, and will throw error otherwise.
bool assert_mode_;
// renaming of NodeRef to indicate two nodes equals to each other // renaming of NodeRef to indicate two nodes equals to each other
std::unordered_map<NodeRef, NodeRef, NodeHash, NodeEqual> equal_map_; std::unordered_map<NodeRef, NodeRef, NodeHash, NodeEqual> equal_map_;
}; };
bool AlphaEqual(const Type& lhs, const Type& rhs) { bool AlphaEqual(const Type& lhs, const Type& rhs) {
return AlphaEqualHandler(false).TypeEqual(lhs, rhs); return AlphaEqualHandler(false, false).TypeEqual(lhs, rhs);
} }
bool AlphaEqual(const Expr& lhs, const Expr& rhs) { bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
return AlphaEqualHandler(false).ExprEqual(lhs, rhs); return AlphaEqualHandler(false, false).ExprEqual(lhs, rhs);
} }
// TODO(@jroesch): move to correct namespace? // TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API("relay._make._alpha_equal") TVM_REGISTER_API("relay._make._alpha_equal")
.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) { .set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
return AlphaEqualHandler(false).Equal(a, b); return AlphaEqualHandler(false, false).Equal(a, b);
}); });
TVM_REGISTER_API("relay._make._type_alpha_equal") TVM_REGISTER_API("relay._make._assert_alpha_equal")
.set_body_typed<bool(Type, Type)>([](Type a, Type b) { .set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
return AlphaEqualHandler(false).TypeEqual(a, b); bool alpha_equal = AlphaEqualHandler(false, true).Equal(a, b);
CHECK(alpha_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not alpha equal";
}); });
TVM_REGISTER_API("relay._make._graph_equal") TVM_REGISTER_API("relay._make._graph_equal")
.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) { .set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
return AlphaEqualHandler(true).Equal(a, b); return AlphaEqualHandler(true, false).Equal(a, b);
});
TVM_REGISTER_API("relay._make._assert_graph_equal")
.set_body_typed<void(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
bool graph_equal = AlphaEqualHandler(true, true).Equal(a, b);
CHECK(graph_equal) << AsText(a, true) << " and " << AsText(b, true) << " is not graph equal";
}); });
} // namespace relay } // namespace relay
......
...@@ -89,7 +89,7 @@ std::string Doc::str() { ...@@ -89,7 +89,7 @@ std::string Doc::str() {
return os.str(); return os.str();
} }
Doc PrintVec(const std::vector<Doc>& vec, const Doc& sep) { Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep) {
Doc seq; Doc seq;
if (vec.size() != 0) { if (vec.size() != 0) {
seq = vec[0]; seq = vec[0];
......
...@@ -6,9 +6,9 @@ ...@@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the * to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance * "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at * with the License. You may obtain a copy of the License at
* *
* http://www.apache.org/licenses/LICENSE-2.0 * http://www.apache.org/licenses/LICENSE-2.0
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
...@@ -46,7 +46,11 @@ using DocAtom = std::shared_ptr<DocAtomNode>; ...@@ -46,7 +46,11 @@ using DocAtom = std::shared_ptr<DocAtomNode>;
struct TextNode : DocAtomNode { struct TextNode : DocAtomNode {
std::string str; std::string str;
explicit TextNode(const std::string& str) : str(str) {} explicit TextNode(const std::string& str) : str(str) {
if (str.find_first_of("\t\n") != str.npos) {
LOG(WARNING) << "text node: '" << str << "' should not has tab or newline.";
}
}
}; };
struct LineNode : DocAtomNode { struct LineNode : DocAtomNode {
...@@ -91,8 +95,8 @@ class Doc { ...@@ -91,8 +95,8 @@ class Doc {
// DSL functions // DSL functions
// Render vectors of docs with a separator. e.g. PrintVec([1, 2, 3], f) -> 1f2f3 // Render vectors of docs with a separator. e.g. PrintSep([1, 2, 3], f) -> 1f2f3
Doc PrintVec(const std::vector<Doc>& vec, const Doc& sep = Doc(", ")); Doc PrintSep(const std::vector<Doc>& vec, const Doc& sep = Doc(", "));
// Print a constant bool value. // Print a constant bool value.
Doc PrintBool(bool value); Doc PrintBool(bool value);
// Print a data type. // Print a data type.
...@@ -116,7 +120,8 @@ Doc PrintConstScalar(DataType dtype, const T* data) { ...@@ -116,7 +120,8 @@ Doc PrintConstScalar(DataType dtype, const T* data) {
} else if (dtype == Bool()) { } else if (dtype == Bool()) {
return PrintBool(data[0] != 0); return PrintBool(data[0] != 0);
} else { } else {
os << dtype << "(" << data[0] << ")"; // todo(@M.K.) this is unsafe. fix.
os << data[0];
} }
return Doc(os.str()); return Doc(os.str());
} }
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
* - Otherwise, inline if the node is at the end of a scope and is used at most once. * - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/ */
#include <dmlc/json.h>
#include <tvm/relay/expr_functor.h> #include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h> #include <tvm/relay/module.h>
#include <tvm/relay/pattern_functor.h> #include <tvm/relay/pattern_functor.h>
...@@ -43,6 +44,17 @@ ...@@ -43,6 +44,17 @@
namespace tvm { namespace tvm {
namespace relay { namespace relay {
Doc Brace(const Doc& d,
const std::string& open = "{",
const std::string& close = "}",
int indent = 2) {
Doc doc;
doc << open;
doc << Indent(indent, PrintNewLine() << d) << PrintNewLine();
doc << close;
return doc;
}
/*! /*!
* \brief Meta data context for PrettyPrinter. * \brief Meta data context for PrettyPrinter.
* *
...@@ -108,8 +120,10 @@ class TextMetaDataContext { ...@@ -108,8 +120,10 @@ class TextMetaDataContext {
if (it != meta_repr_.end()) { if (it != meta_repr_.end()) {
return it->second; return it->second;
} }
std::string type_key = node->type_key();
CHECK(!type_key.empty());
Array<NodeRef>& mvector = Array<NodeRef>& mvector =
meta_data_[node->type_key()]; meta_data_[type_key];
int64_t index = static_cast<int64_t>(mvector.size()); int64_t index = static_cast<int64_t>(mvector.size());
mvector.push_back(node); mvector.push_back(node);
Doc doc; Doc doc;
...@@ -117,14 +131,18 @@ class TextMetaDataContext { ...@@ -117,14 +131,18 @@ class TextMetaDataContext {
meta_repr_[node] = doc; meta_repr_[node] = doc;
return meta_repr_[node]; return meta_repr_[node];
} }
Doc PrintKeyValue(const std::string& str, const Doc& v) const {
return Doc("\"") << str << "\": " << v;
}
/*! /*!
* \brief Get the metadata section in json format. * \brief Get the metadata section in json format.
* \return the meta data string. * \return the meta data string.
*/ */
std::string GetMetaSection() const { Doc GetMetaSection() const {
if (meta_data_.size() == 0) return std::string(); if (meta_data_.size() == 0) return Doc();
return SaveJSON(Map<std::string, NodeRef>( return Doc(SaveJSON(Map<std::string, NodeRef>(meta_data_.begin(), meta_data_.end())));
meta_data_.begin(), meta_data_.end()));
} }
/*! \return whether the meta data context is empty. */ /*! \return whether the meta data context is empty. */
...@@ -172,12 +190,11 @@ class PrettyPrinter : ...@@ -172,12 +190,11 @@ class PrettyPrinter :
} }
// indent a new body // indent a new body
// TODO(jmp): indent should be an instance variable of the printer
Doc PrintBody(const NodeRef& node, int indent = 2) { Doc PrintBody(const NodeRef& node, int indent = 2) {
Doc doc; Doc doc;
Doc body; Doc body;
doc << "{"; doc << "{";
doc << Indent(indent, body << "\n" << PrintScope(node)) << "\n"; doc << Indent(indent, body << PrintNewLine() << PrintScope(node)) << PrintNewLine();
doc << "}"; doc << "}";
return doc; return doc;
} }
...@@ -203,13 +220,12 @@ class PrettyPrinter : ...@@ -203,13 +220,12 @@ class PrettyPrinter :
Doc doc; Doc doc;
doc << PrintScope(node); doc << PrintScope(node);
if (!meta_.empty()) { if (!meta_.empty()) {
doc << PrintNewLine();
if (show_meta_data_) { if (show_meta_data_) {
std::string meta_json = meta_.GetMetaSection();
// append meta data in the end. // append meta data in the end.
doc << "\n" << "/* meta data */" << "\n" << meta_json; doc << "METADATA:" << PrintNewLine() << meta_.GetMetaSection();
} else { } else {
doc << "\n" doc << "// meta data omitted. you can use show_meta_data=True to include meta data";
<< "// meta data omitted. you can use show_meta_data=True to include meta data";
} }
} }
return doc; return doc;
...@@ -361,7 +377,7 @@ class PrettyPrinter : ...@@ -361,7 +377,7 @@ class PrettyPrinter :
// wrap GNFed let in brackets // wrap GNFed let in brackets
Doc body; Doc body;
printed_expr << "{"; printed_expr << "{";
printed_expr << Indent(2, body << "\n" << VisitExpr(expr)) << "\n"; printed_expr << Indent(2, body << PrintNewLine() << VisitExpr(expr)) << PrintNewLine();
printed_expr << "}"; printed_expr << "}";
} else { } else {
printed_expr = VisitExpr(expr); printed_expr = VisitExpr(expr);
...@@ -373,7 +389,7 @@ class PrettyPrinter : ...@@ -373,7 +389,7 @@ class PrettyPrinter :
if (expr.as<VarNode>()) { if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case // This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free. // in the visitor. Thus the variable is free.
doc_stack_.back() << "free_var " << printed_expr << "\n"; doc_stack_.back() << "free_var " << printed_expr << PrintNewLine();
// Memoization is done in AllocVar. // Memoization is done in AllocVar.
return memo_[expr]; return memo_[expr];
} else if (inline_expr) { } else if (inline_expr) {
...@@ -422,7 +438,7 @@ class PrettyPrinter : ...@@ -422,7 +438,7 @@ class PrettyPrinter :
fields.push_back(Print(field)); fields.push_back(Print(field));
} }
Doc doc; Doc doc;
doc << "(" << PrintVec(fields); doc << "(" << PrintSep(fields);
// conform to python tuple format (1,) // conform to python tuple format (1,)
if (op->fields.size() == 1) { if (op->fields.size() == 1) {
doc << ","; doc << ",";
...@@ -460,31 +476,31 @@ class PrettyPrinter : ...@@ -460,31 +476,31 @@ class PrettyPrinter :
} }
Doc PrintFunc(const Doc& prefix, const Function& fn) { Doc PrintFunc(const Doc& prefix, const Function& fn) {
Doc doc; Doc doc;
doc << prefix; doc << prefix;
if (fn->type_params.size() > 0) { if (fn->type_params.size() > 0) {
doc << "<"; doc << "<";
std::vector<Doc> type_params; std::vector<Doc> type_params;
for (const TypeVar& tv : fn->type_params) { for (const TypeVar& tv : fn->type_params) {
type_params.push_back(AllocTypeVar(tv)); type_params.push_back(AllocTypeVar(tv));
}
doc << PrintVec(type_params);
doc << ">";
}
doc << "(";
std::vector<Doc> params;
for (Var param : fn->params) {
params.push_back(AllocVar(param));
}
for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
params.push_back(d);
}
doc << PrintVec(params) << ") ";
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
} }
doc << PrintBody(fn->body); doc << PrintSep(type_params);
return doc; doc << ">";
}
doc << "(";
std::vector<Doc> params;
for (Var param : fn->params) {
params.push_back(AllocVar(param));
}
for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
params.push_back(d);
}
doc << PrintSep(params) << ") ";
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
}
doc << PrintBody(fn->body);
return doc;
} }
Doc PrintMod(const Module& mod) { Doc PrintMod(const Module& mod) {
...@@ -493,13 +509,13 @@ class PrettyPrinter : ...@@ -493,13 +509,13 @@ class PrettyPrinter :
for (const auto& kv : mod->functions) { for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second); dg_ = DependencyGraph::Create(&arena_, kv.second);
std::ostringstream os;
if (counter++ != 0) { if (counter++ != 0) {
doc << "\n"; doc << PrintNewLine();
} }
std::ostringstream os;
os << "def @" << kv.first->name_hint; os << "def @" << kv.first->name_hint;
doc << PrintFunc(Doc(os.str()), kv.second); doc << PrintFunc(Doc(os.str()), kv.second);
doc << "\n"; doc << PrintNewLine();
} }
return doc; return doc;
} }
...@@ -528,7 +544,7 @@ class PrettyPrinter : ...@@ -528,7 +544,7 @@ class PrettyPrinter :
args.push_back(d); args.push_back(d);
} }
doc << Print(op->op); doc << Print(op->op);
return doc << "(" << PrintVec(args) << ")"; return doc << "(" << PrintSep(args) << ")";
} }
Doc VisitExpr_(const RefCreateNode* op) final { Doc VisitExpr_(const RefCreateNode* op) final {
...@@ -558,7 +574,7 @@ class PrettyPrinter : ...@@ -558,7 +574,7 @@ class PrettyPrinter :
clauses.push_back(clause_doc << Print(clause->lhs) << " -> " clauses.push_back(clause_doc << Print(clause->lhs) << " -> "
<< Print(clause->rhs)); << Print(clause->rhs));
} }
doc << Indent(2, body << "\n" << PrintVec(clauses, Doc("\n"))) << "\n"; doc << Indent(2, body << PrintNewLine() << PrintSep(clauses, PrintNewLine())) << PrintNewLine();
doc << "}"; doc << "}";
return doc; return doc;
} }
...@@ -570,7 +586,7 @@ class PrettyPrinter : ...@@ -570,7 +586,7 @@ class PrettyPrinter :
for (const auto& pat : p->patterns) { for (const auto& pat : p->patterns) {
pats.push_back(Print(pat)); pats.push_back(Print(pat));
} }
return doc << PrintVec(pats) << ")"; return doc << PrintSep(pats) << ")";
} }
Doc VisitPattern_(const PatternVarNode* pv) final { Doc VisitPattern_(const PatternVarNode* pv) final {
...@@ -617,7 +633,7 @@ class PrettyPrinter : ...@@ -617,7 +633,7 @@ class PrettyPrinter :
args.push_back(PrintType(t, false)); args.push_back(PrintType(t, false));
} }
doc << "["; doc << "[";
doc << PrintVec(args); doc << PrintSep(args);
doc << "]"; doc << "]";
return doc; return doc;
} }
...@@ -633,11 +649,7 @@ class PrettyPrinter : ...@@ -633,11 +649,7 @@ class PrettyPrinter :
for (NodeRef shape : node->shape) { for (NodeRef shape : node->shape) {
shapes.push_back(PrintAttr(shape)); shapes.push_back(PrintAttr(shape));
} }
doc << PrintVec(shapes); doc << PrintSep(shapes);
// conform to python tuple format (1,)
if (node->shape.size() == 1) {
doc << ",";
}
return doc << "), " << PrintDType(node->dtype) << "]"; return doc << "), " << PrintDType(node->dtype) << "]";
} }
...@@ -647,7 +659,7 @@ class PrettyPrinter : ...@@ -647,7 +659,7 @@ class PrettyPrinter :
fields.push_back(Print(field)); fields.push_back(Print(field));
} }
Doc doc; Doc doc;
doc << "(" << PrintVec(fields); doc << "(" << PrintSep(fields);
// conform to python tuple format (1,) // conform to python tuple format (1,)
if (node->fields.size() == 1) { if (node->fields.size() == 1) {
doc << ","; doc << ",";
...@@ -664,14 +676,14 @@ class PrettyPrinter : ...@@ -664,14 +676,14 @@ class PrettyPrinter :
for (Type type_param : node->type_params) { for (Type type_param : node->type_params) {
type_params.push_back(Print(type_param)); type_params.push_back(Print(type_param));
} }
doc << PrintVec(type_params); doc << PrintSep(type_params);
doc << ">"; doc << ">";
} }
std::vector<Doc> arg_types; std::vector<Doc> arg_types;
for (Type arg_type : node->arg_types) { for (Type arg_type : node->arg_types) {
arg_types.push_back(Print(arg_type)); arg_types.push_back(Print(arg_type));
} }
return doc << "(" << PrintVec(arg_types) << ") -> " << Print(node->ret_type); return doc << "(" << PrintSep(arg_types) << ") -> " << Print(node->ret_type);
} }
Doc VisitType_(const RefTypeNode* node) final { Doc VisitType_(const RefTypeNode* node) final {
...@@ -710,7 +722,7 @@ class PrettyPrinter : ...@@ -710,7 +722,7 @@ class PrettyPrinter :
for (NodePtr<Node> val : op->data) { for (NodePtr<Node> val : op->data) {
arr_vals.push_back(PrintAttr(NodeRef(val))); arr_vals.push_back(PrintAttr(NodeRef(val)));
} }
doc << PrintVec(arr_vals); doc << PrintSep(arr_vals);
doc << "]"; doc << "]";
return doc; return doc;
} }
...@@ -771,7 +783,9 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor { ...@@ -771,7 +783,9 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
} }
void Visit(const char* key, double* value) final { void Visit(const char* key, double* value) final {
PrintKV(key, *value); Doc doc;
doc << key << "=" << *value << "f";
docs->push_back(doc);
} }
void Visit(const char* key, int64_t* value) final { void Visit(const char* key, int64_t* value) final {
PrintKV(key, *value); PrintKV(key, *value);
...@@ -843,7 +857,7 @@ std::string PrettyPrint_(const NodeRef& node, ...@@ -843,7 +857,7 @@ std::string PrettyPrint_(const NodeRef& node,
bool show_meta_data, bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) { runtime::TypedPackedFunc<std::string(Expr)> annotate) {
Doc doc; Doc doc;
doc << "v0.0.3" << "\n" doc << "v0.0.3" << PrintNewLine()
<< PrettyPrinter(show_meta_data, annotate).PrintFinal(node); << PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str(); return doc.str();
} }
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
# under the License. # under the License.
import tvm import tvm
from tvm import relay from tvm import relay
from tvm.relay.analysis import alpha_equal from tvm.relay.analysis import alpha_equal, assert_alpha_equal
from nose.tools import nottest, raises from nose.tools import nottest, raises
from numpy import isclose from numpy import isclose
from typing import Union from typing import Union
...@@ -60,12 +60,9 @@ TYPES = { ...@@ -60,12 +60,9 @@ TYPES = {
"float16x4", "float16x4",
} }
def assert_alpha_equal(a, b):
if not alpha_equal(a, b):
raise Exception("lhs is: ", str(a), "rhs is: ", str(b))
def roundtrip(expr): def roundtrip(expr):
assert_alpha_equal(relay.fromtext(str(expr)), expr) x = relay.fromtext(str(expr))
assert_alpha_equal(x, expr)
def parse_text(code): def parse_text(code):
...@@ -112,6 +109,16 @@ def test_comments(): ...@@ -112,6 +109,16 @@ def test_comments():
UNIT UNIT
) )
assert parses_as(
"""
/* This is a block comment!
/*Block comment is recursive!*/
*/
()
""",
UNIT
)
def test_int_literal(): def test_int_literal():
assert isinstance(parse_text("1"), relay.Constant) assert isinstance(parse_text("1"), relay.Constant)
...@@ -224,7 +231,7 @@ def test_let(): ...@@ -224,7 +231,7 @@ def test_let():
def test_seq(): def test_seq():
assert parses_as( assert parses_as(
"(); ()", "();; ()",
relay.Let( relay.Let(
_, _,
UNIT, UNIT,
...@@ -538,7 +545,7 @@ def test_tensor_type(): ...@@ -538,7 +545,7 @@ def test_tensor_type():
) )
assert parses_as( assert parses_as(
"let %_ : Tensor[(1,), float32] = (); ()", "let %_ : Tensor[(1), float32] = (); ()",
relay.Let( relay.Let(
relay.Var("_", relay.TensorType((1,), "float32")), relay.Var("_", relay.TensorType((1,), "float32")),
UNIT, UNIT,
......
...@@ -15,14 +15,27 @@ ...@@ -15,14 +15,27 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import tvm import tvm
from tvm import relay
import tvm.relay.testing import tvm.relay.testing
import numpy as np import numpy as np
from tvm import relay from tvm.relay import Expr
from tvm.relay.analysis import alpha_equal, assert_alpha_equal, assert_graph_equal, free_vars
do_print = [False] do_print = [False]
SEMVER = "v0.0.3\n" SEMVER = "v0.0.3\n"
def astext(p, graph_equal=False):
txt = p.astext()
if isinstance(p, Expr) and free_vars(p):
return txt
x = relay.fromtext(txt)
if graph_equal:
assert_graph_equal(x, p)
else:
assert_alpha_equal(x, p)
return txt
def show(text): def show(text):
if do_print[0]: if do_print[0]:
print("---------------------------") print("---------------------------")
...@@ -35,8 +48,8 @@ def test_func(): ...@@ -35,8 +48,8 @@ def test_func():
z = relay.add(x, one) z = relay.add(x, one)
z = relay.add(z, z) z = relay.add(z, z)
f = relay.Function([x, y], z) f = relay.Function([x, y], z)
show(z.astext()) show(astext(z))
show(f.astext()) show(astext(f))
def test_env(): def test_env():
...@@ -47,7 +60,7 @@ def test_env(): ...@@ -47,7 +60,7 @@ def test_env():
f = relay.Function([x, y], z) f = relay.Function([x, y], z)
env = relay.Module() env = relay.Module()
env["myf"] = f env["myf"] = f
text = env.astext() text = astext(env)
assert "def @myf" in text assert "def @myf" in text
assert "def @myf" in str(env) assert "def @myf" in str(env)
assert "add(%0, %0) /* ty=float32 */" in text assert "add(%0, %0) /* ty=float32 */" in text
...@@ -65,7 +78,7 @@ def test_meta_data(): ...@@ -65,7 +78,7 @@ def test_meta_data():
padding=(1, 1), padding=(1, 1),
channels=2) channels=2)
f = relay.Function([x, w], z) f = relay.Function([x, w], z)
text = f.astext() text = astext(f, graph_equal=True)
text_no_meta = str(f) text_no_meta = str(f)
assert "channels=2" in text assert "channels=2" in text
assert "channels=2" in text_no_meta assert "channels=2" in text_no_meta
...@@ -73,25 +86,22 @@ def test_meta_data(): ...@@ -73,25 +86,22 @@ def test_meta_data():
assert "meta[Variable][0]" in text_no_meta assert "meta[Variable][0]" in text_no_meta
assert "type_key" in text assert "type_key" in text
assert "type_key" not in text_no_meta assert "type_key" not in text_no_meta
show(text)
show(f)
text = relay.const([1,2,3]).astext() text = astext(relay.const([1,2,3]))
assert "meta[relay.Constant][0]" in text assert "meta[relay.Constant][0]" in text
show(text)
def test_call_attrs(): def test_call_attrs():
x = relay.var("x") x = relay.var("x")
# non default args # non default args
z = relay.nn.softmax(x, axis=2) z = relay.nn.softmax(x, axis=2)
assert "axis=2" in z.astext() assert "axis=2" in astext(z)
# default args # default args
z = relay.nn.softmax(x) z = relay.nn.softmax(x)
assert "softmax(%x)" in z.astext() assert "softmax(%x)" in astext(z)
# non default args # non default args
z = relay.expand_dims(x, axis=2, num_newaxis=2) z = relay.expand_dims(x, axis=2, num_newaxis=2)
assert "num_newaxis=2" in z.astext() assert "num_newaxis=2" in astext(z)
def test_let_if_scope(): def test_let_if_scope():
...@@ -111,68 +121,72 @@ def test_let_if_scope(): ...@@ -111,68 +121,72 @@ def test_let_if_scope():
result = sb.get() result = sb.get()
f = relay.Function([x, y, cond], result) f = relay.Function([x, y, cond], result)
text = f.astext() text = astext(f)
assert text.count("{") == 4 assert text.count("{") == 4
assert "%cond: bool" in text assert "%cond: bool" in text
show(f.astext()) show(astext(f))
def test_variable_name(): def test_variable_name():
# avoid pure number even if the namehint is pure number # avoid pure number even if the namehint is pure number
v1 = relay.var("1") v1 = relay.var("1")
assert "%v1" in v1.astext() assert "%v1" in astext(v1)
def test_mlp(): def test_mlp():
net, params = tvm.relay.testing.mlp.get_workload(batch_size=1) net, params = tvm.relay.testing.mlp.get_workload(batch_size=1)
net.astext() astext(net)
def test_resnet(): def test_resnet():
net, params = tvm.relay.testing.resnet.get_workload(batch_size=1) net, params = tvm.relay.testing.resnet.get_workload(batch_size=1)
net.astext() astext(net)
def test_mobilenet(): def test_mobilenet():
net, params = tvm.relay.testing.mobilenet.get_workload(batch_size=1) net, params = tvm.relay.testing.mobilenet.get_workload(batch_size=1)
net.astext() astext(net)
def test_dqn(): def test_dqn():
net, params = tvm.relay.testing.dqn.get_workload(batch_size=1) net, params = tvm.relay.testing.dqn.get_workload(batch_size=1)
net.astext() astext(net)
def test_dcgan(): def test_dcgan():
net, params = tvm.relay.testing.dcgan.get_workload(batch_size=1) net, params = tvm.relay.testing.dcgan.get_workload(batch_size=1)
net.astext() astext(net)
def test_lstm(): def test_lstm():
net, params = tvm.relay.testing.lstm.get_workload(1, 1)
astext(net)
net, params = tvm.relay.testing.lstm.get_workload(4, 4) net, params = tvm.relay.testing.lstm.get_workload(4, 4)
net.astext() astext(net)
def test_inception_v3(): def test_inception_v3():
net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1) net, params = tvm.relay.testing.inception_v3.get_workload(batch_size=1)
net.astext() astext(net)
def test_squeezenet(): def test_squeezenet():
for version in ['1.0', '1.1']: for version in ['1.0', '1.1']:
net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version) net, params = tvm.relay.testing.squeezenet.get_workload(batch_size=1, version=version)
net.astext() astext(net)
def test_vgg(): def test_vgg():
net, params = tvm.relay.testing.vgg.get_workload(batch_size=1) net, params = tvm.relay.testing.vgg.get_workload(batch_size=1)
net.astext() astext(net)
def test_densenet(): def test_densenet():
net, params = tvm.relay.testing.densenet.get_workload(batch_size=1) net, params = tvm.relay.testing.densenet.get_workload(batch_size=1)
net.astext() astext(net)
def test_call_node_order(): def test_call_node_order():
x = relay.var("x") x = relay.var("x")
y = relay.var("y") y = relay.var("y")
assert relay.Call(relay.Function([x], x), [relay.Call(relay.Function([y], y), [relay.const(1)])]).astext() == SEMVER + \ prog = relay.Call(relay.Function([x], x), [relay.Call(relay.Function([y], y), [relay.const(1)])])
assert astext(prog) == SEMVER + \
("%0 = fn (%y) {\n" ("%0 = fn (%y) {\n"
" %y\n" " %y\n"
"};\n" "};\n"
...@@ -185,17 +199,25 @@ def test_call_node_order(): ...@@ -185,17 +199,25 @@ def test_call_node_order():
def test_let_inlining(): def test_let_inlining():
tup = relay.Tuple([relay.const(0), relay.const(0)]) tup = relay.Tuple([relay.const(0), relay.const(0)])
x = relay.var("x") x = relay.var("x")
assert relay.Let(x, tup, tup).astext() == SEMVER + \ assert astext(relay.Let(x, tup, tup)) == SEMVER + \
("%0 = (0, 0);\n" ("%0 = (0, 0);\n"
"let %x = %0;\n" "let %x = %0;\n"
"%0") "%0")
assert relay.Let(x, tup, x).astext() == SEMVER + \ assert astext(relay.Let(x, tup, x)) == SEMVER + \
("let %x = (0, 0);\n" ("let %x = (0, 0);\n"
"%x") "%x")
def test_zeros():
x = relay.op.zeros([], "float32")
astext(x)
if __name__ == "__main__": if __name__ == "__main__":
do_print[0] = True do_print[0] = True
test_lstm()
test_zeros()
test_meta_data()
test_let_inlining()
test_resnet() test_resnet()
test_mobilenet() test_mobilenet()
test_mlp() test_mlp()
...@@ -207,9 +229,7 @@ if __name__ == "__main__": ...@@ -207,9 +229,7 @@ if __name__ == "__main__":
test_densenet() test_densenet()
test_func() test_func()
test_env() test_env()
test_meta_data()
test_call_attrs() test_call_attrs()
test_let_if_scope() test_let_if_scope()
test_variable_name() test_variable_name()
test_call_node_order() test_call_node_order()
test_let_inlining()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment