Commit 7fb9557b by 雾雨魔理沙 Committed by Jared Roesch

[Relay] Roundtrip part of pretty printer and parser (#3460)

* init

fix rebase

lint

fix cmake

try again

fix ci

* add gitignore

* fix format

* do not include .interp and .tokens
parent 7b988016
......@@ -217,7 +217,7 @@ patched.txt
.mypy_cache/
.pyre/
# pipenv file
# pipenv files
Pipfile
Pipfile.lock
......@@ -225,5 +225,10 @@ Pipfile.lock
conda/Dockerfile.cuda*
conda/pkg
# nix files
.envrc
*.nix
\ No newline at end of file
*.nix
# antlr files
*.tokens
*.interp
\ No newline at end of file
......@@ -14,21 +14,49 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
if(USE_ANTLR)
set(RELAY_PARSER_DIR
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar)
find_program(ANTLR4 antlr4)
if (NOT ANTLR4)
file(GLOB_RECURSE ANTLR4JAR
/usr/local/lib/antlr-*-complete.jar
/usr/local/Cellar/*antlr-*-complete.jar)
# Get the first element of the list of antlr jars.
# Sort and reverse the list so the item selected is the highest
# version in lib or else in Cellar if no lib installation exists.
list(SORT ANTLR4JAR)
list(REVERSE ANTLR4JAR)
list(GET ANTLR4JAR 0 ANTLR4JAR)
set(JAVA_HOME $ENV{JAVA_HOME})
if (NOT DEFINED JAVA_HOME)
# Hack to get system to search for Java itself.
set(JAVA_HOME "/usr")
endif()
set(ANTLR4 ${JAVA_HOME}/bin/java -jar ${ANTLR4JAR})
endif()
if(ANTLR4)
set(RELAY_PARSER_DIR
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar)
set(RELAY_PARSER
${RELAY_PARSER_DIR}/py3/RelayVisitor.py
${RELAY_PARSER_DIR}/py3/RelayParser.py
${RELAY_PARSER_DIR}/py3/RelayLexer.py)
set(RELAY_PARSER
${RELAY_PARSER_DIR}/py3/RelayVisitor.py
${RELAY_PARSER_DIR}/py3/RelayParser.py
${RELAY_PARSER_DIR}/py3/RelayLexer.py)
# Generate ANTLR grammar for parsing.
add_custom_command(OUTPUT ${RELAY_PARSER}
COMMAND antlr4 -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3
DEPENDS ${RELAY_PARSER_DIR}/Relay.g4
WORKING_DIRECTORY ${RELAY_PARSER_DIR})
# Generate ANTLR grammar for parsing.
add_custom_command(OUTPUT ${RELAY_PARSER}
COMMAND ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3
DEPENDS ${RELAY_PARSER_DIR}/Relay.g4
WORKING_DIRECTORY ${RELAY_PARSER_DIR})
add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER})
add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER})
else()
message(FATAL_ERROR "Can't find ANTLR4")
endif()
endif(USE_ANTLR)
......@@ -206,7 +206,7 @@ class ParseTreeToRelayIR(RelayVisitor):
if node_type == RelayLexer.NAT:
return int(node_text)
if node_type == RelayLexer.FLOAT:
return float(node_text)
return float(node_text[:-1])
if node_type == RelayLexer.BOOL_LIT:
if node_text == "True":
return True
......@@ -375,6 +375,8 @@ class ParseTreeToRelayIR(RelayVisitor):
self.mk_typ(name, ty.Kind.Type)
var_list, attr_list = self.visit(ctx.argList())
if var_list is None:
var_list = []
ret_type = self.getType_(ctx.type_())
body = self.visit(ctx.body())
......@@ -387,7 +389,6 @@ class ParseTreeToRelayIR(RelayVisitor):
self.exit_var_scope()
attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None
return expr.Function(var_list, body, ret_type, type_params, attrs)
@spanify
......
......@@ -19,7 +19,7 @@
grammar Relay;
SEMVER: 'v0.0.2' ;
SEMVER: 'v0.0.3' ;
// Lexing
// comments
......@@ -52,10 +52,9 @@ BOOL_LIT
;
// non-negative floats
FLOAT
: NAT '.' NAT EXP? // 1.35, 1.35E-9, 0.3, 4.5
| NAT EXP // 1e10 3e4
;
fragment PREFLOAT : NAT ('.' NAT)? EXP?; // 1.35, 1.35E-9, 0.3, 4.5, 1, 1e10 3e4
FLOAT : PREFLOAT 'f';
// non-negative ints
NAT: DIGIT+ ;
......
Relay* binary
Relay* linguist-generated=true
Relay* linguist-detectable=false
T__0=1
T__1=2
T__2=3
T__3=4
T__4=5
T__5=6
T__6=7
T__7=8
T__8=9
T__9=10
T__10=11
T__11=12
T__12=13
T__13=14
T__14=15
T__15=16
T__16=17
T__17=18
SEMVER=19
WS=20
LINE_COMMENT=21
COMMENT=22
MUL=23
DIV=24
ADD=25
SUB=26
LT=27
GT=28
LE=29
GE=30
EQ=31
NE=32
GLOBAL_VAR=33
LOCAL_VAR=34
GRAPH_VAR=35
MUT=36
BOOL_LIT=37
FLOAT=38
NAT=39
CNAME=40
'('=1
')'=2
','=3
'['=4
']'=5
'if'=6
'else'=7
'let'=8
'='=9
';'=10
'{'=11
'}'=12
'fn'=13
'->'=14
'def'=15
':'=16
'Tensor'=17
'_'=18
'v0.0.2'=19
'*'=23
'/'=24
'+'=25
'-'=26
'<'=27
'>'=28
'<='=29
'>='=30
'=='=31
'!='=32
'mut'=36
token literal names:
null
'('
')'
','
'['
']'
'if'
'else'
'let'
'='
';'
'{'
'}'
'fn'
'->'
'def'
':'
'Tensor'
'_'
'v0.0.2'
null
null
null
'*'
'/'
'+'
'-'
'<'
'>'
'<='
'>='
'=='
'!='
null
null
null
'mut'
null
null
null
null
token symbolic names:
null
null
null
null
null
null
null
null
null
null
null
null
null
null
null
null
null
null
null
SEMVER
WS
LINE_COMMENT
COMMENT
MUL
DIV
ADD
SUB
LT
GT
LE
GE
EQ
NE
GLOBAL_VAR
LOCAL_VAR
GRAPH_VAR
MUT
BOOL_LIT
FLOAT
NAT
CNAME
rule names:
T__0
T__1
T__2
T__3
T__4
T__5
T__6
T__7
T__8
T__9
T__10
T__11
T__12
T__13
T__14
T__15
T__16
T__17
SEMVER
WS
LINE_COMMENT
COMMENT
MUL
DIV
ADD
SUB
LT
GT
LE
GE
EQ
NE
GLOBAL_VAR
LOCAL_VAR
GRAPH_VAR
MUT
BOOL_LIT
FLOAT
NAT
EXP
CNAME
LETTER
DIGIT
channel names:
DEFAULT_TOKEN_CHANNEL
HIDDEN
mode names:
DEFAULT_MODE
atn:
[3, 24715, 42794, 33075, 47597, 16764, 15335, 30598, 22884, 2, 42, 267, 8, 1, 4, 2, 9, 2, 4, 3, 9, 3, 4, 4, 9, 4, 4, 5, 9, 5, 4, 6, 9, 6, 4, 7, 9, 7, 4, 8, 9, 8, 4, 9, 9, 9, 4, 10, 9, 10, 4, 11, 9, 11, 4, 12, 9, 12, 4, 13, 9, 13, 4, 14, 9, 14, 4, 15, 9, 15, 4, 16, 9, 16, 4, 17, 9, 17, 4, 18, 9, 18, 4, 19, 9, 19, 4, 20, 9, 20, 4, 21, 9, 21, 4, 22, 9, 22, 4, 23, 9, 23, 4, 24, 9, 24, 4, 25, 9, 25, 4, 26, 9, 26, 4, 27, 9, 27, 4, 28, 9, 28, 4, 29, 9, 29, 4, 30, 9, 30, 4, 31, 9, 31, 4, 32, 9, 32, 4, 33, 9, 33, 4, 34, 9, 34, 4, 35, 9, 35, 4, 36, 9, 36, 4, 37, 9, 37, 4, 38, 9, 38, 4, 39, 9, 39, 4, 40, 9, 40, 4, 41, 9, 41, 4, 42, 9, 42, 4, 43, 9, 43, 4, 44, 9, 44, 3, 2, 3, 2, 3, 3, 3, 3, 3, 4, 3, 4, 3, 5, 3, 5, 3, 6, 3, 6, 3, 7, 3, 7, 3, 7, 3, 8, 3, 8, 3, 8, 3, 8, 3, 8, 3, 9, 3, 9, 3, 9, 3, 9, 3, 10, 3, 10, 3, 11, 3, 11, 3, 12, 3, 12, 3, 13, 3, 13, 3, 14, 3, 14, 3, 14, 3, 15, 3, 15, 3, 15, 3, 16, 3, 16, 3, 16, 3, 16, 3, 17, 3, 17, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 18, 3, 19, 3, 19, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 20, 3, 21, 6, 21, 149, 10, 21, 13, 21, 14, 21, 150, 3, 21, 3, 21, 3, 22, 3, 22, 3, 22, 3, 22, 7, 22, 159, 10, 22, 12, 22, 14, 22, 162, 11, 22, 3, 22, 3, 22, 3, 22, 3, 22, 3, 23, 3, 23, 3, 23, 3, 23, 7, 23, 172, 10, 23, 12, 23, 14, 23, 175, 11, 23, 3, 23, 3, 23, 3, 23, 3, 23, 3, 23, 3, 24, 3, 24, 3, 25, 3, 25, 3, 26, 3, 26, 3, 27, 3, 27, 3, 28, 3, 28, 3, 29, 3, 29, 3, 30, 3, 30, 3, 30, 3, 31, 3, 31, 3, 31, 3, 32, 3, 32, 3, 32, 3, 33, 3, 33, 3, 33, 3, 34, 3, 34, 3, 34, 3, 35, 3, 35, 3, 35, 3, 36, 3, 36, 3, 36, 3, 37, 3, 37, 3, 37, 3, 37, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 3, 38, 5, 38, 228, 10, 38, 3, 39, 3, 39, 3, 39, 3, 39, 5, 39, 234, 10, 39, 3, 39, 3, 39, 3, 39, 5, 39, 239, 10, 39, 3, 40, 6, 40, 242, 10, 40, 13, 40, 14, 40, 243, 3, 41, 3, 41, 5, 41, 248, 10, 41, 3, 41, 3, 41, 3, 42, 3, 42, 5, 42, 254, 10, 42, 3, 42, 3, 42, 3, 42, 7, 42, 259, 10, 42, 12, 42, 14, 42, 262, 11, 42, 3, 43, 3, 43, 3, 44, 3, 44, 4, 160, 173, 2, 45, 3, 3, 5, 4, 7, 5, 9, 6, 11, 7, 13, 8, 15, 9, 17, 10, 19, 11, 21, 12, 23, 13, 25, 14, 27, 15, 29, 16, 31, 17, 33, 18, 35, 19, 37, 20, 39, 21, 41, 22, 43, 23, 45, 24, 47, 25, 49, 26, 51, 27, 53, 28, 55, 29, 57, 30, 59, 31, 61, 32, 63, 33, 65, 34, 67, 35, 69, 36, 71, 37, 73, 38, 75, 39, 77, 40, 79, 41, 81, 2, 83, 42, 85, 2, 87, 2, 3, 2, 7, 5, 2, 11, 12, 15, 15, 34, 34, 4, 2, 71, 71, 103, 103, 4, 2, 45, 45, 47, 47, 4, 2, 67, 92, 99, 124, 3, 2, 50, 59, 2, 275, 2, 3, 3, 2, 2, 2, 2, 5, 3, 2, 2, 2, 2, 7, 3, 2, 2, 2, 2, 9, 3, 2, 2, 2, 2, 11, 3, 2, 2, 2, 2, 13, 3, 2, 2, 2, 2, 15, 3, 2, 2, 2, 2, 17, 3, 2, 2, 2, 2, 19, 3, 2, 2, 2, 2, 21, 3, 2, 2, 2, 2, 23, 3, 2, 2, 2, 2, 25, 3, 2, 2, 2, 2, 27, 3, 2, 2, 2, 2, 29, 3, 2, 2, 2, 2, 31, 3, 2, 2, 2, 2, 33, 3, 2, 2, 2, 2, 35, 3, 2, 2, 2, 2, 37, 3, 2, 2, 2, 2, 39, 3, 2, 2, 2, 2, 41, 3, 2, 2, 2, 2, 43, 3, 2, 2, 2, 2, 45, 3, 2, 2, 2, 2, 47, 3, 2, 2, 2, 2, 49, 3, 2, 2, 2, 2, 51, 3, 2, 2, 2, 2, 53, 3, 2, 2, 2, 2, 55, 3, 2, 2, 2, 2, 57, 3, 2, 2, 2, 2, 59, 3, 2, 2, 2, 2, 61, 3, 2, 2, 2, 2, 63, 3, 2, 2, 2, 2, 65, 3, 2, 2, 2, 2, 67, 3, 2, 2, 2, 2, 69, 3, 2, 2, 2, 2, 71, 3, 2, 2, 2, 2, 73, 3, 2, 2, 2, 2, 75, 3, 2, 2, 2, 2, 77, 3, 2, 2, 2, 2, 79, 3, 2, 2, 2, 2, 83, 3, 2, 2, 2, 3, 89, 3, 2, 2, 2, 5, 91, 3, 2, 2, 2, 7, 93, 3, 2, 2, 2, 9, 95, 3, 2, 2, 2, 11, 97, 3, 2, 2, 2, 13, 99, 3, 2, 2, 2, 15, 102, 3, 2, 2, 2, 17, 107, 3, 2, 2, 2, 19, 111, 3, 2, 2, 2, 21, 113, 3, 2, 2, 2, 23, 115, 3, 2, 2, 2, 25, 117, 3, 2, 2, 2, 27, 119, 3, 2, 2, 2, 29, 122, 3, 2, 2, 2, 31, 125, 3, 2, 2, 2, 33, 129, 3, 2, 2, 2, 35, 131, 3, 2, 2, 2, 37, 138, 3, 2, 2, 2, 39, 140, 3, 2, 2, 2, 41, 148, 3, 2, 2, 2, 43, 154, 3, 2, 2, 2, 45, 167, 3, 2, 2, 2, 47, 181, 3, 2, 2, 2, 49, 183, 3, 2, 2, 2, 51, 185, 3, 2, 2, 2, 53, 187, 3, 2, 2, 2, 55, 189, 3, 2, 2, 2, 57, 191, 3, 2, 2, 2, 59, 193, 3, 2, 2, 2, 61, 196, 3, 2, 2, 2, 63, 199, 3, 2, 2, 2, 65, 202, 3, 2, 2, 2, 67, 205, 3, 2, 2, 2, 69, 208, 3, 2, 2, 2, 71, 211, 3, 2, 2, 2, 73, 214, 3, 2, 2, 2, 75, 227, 3, 2, 2, 2, 77, 238, 3, 2, 2, 2, 79, 241, 3, 2, 2, 2, 81, 245, 3, 2, 2, 2, 83, 253, 3, 2, 2, 2, 85, 263, 3, 2, 2, 2, 87, 265, 3, 2, 2, 2, 89, 90, 7, 42, 2, 2, 90, 4, 3, 2, 2, 2, 91, 92, 7, 43, 2, 2, 92, 6, 3, 2, 2, 2, 93, 94, 7, 46, 2, 2, 94, 8, 3, 2, 2, 2, 95, 96, 7, 93, 2, 2, 96, 10, 3, 2, 2, 2, 97, 98, 7, 95, 2, 2, 98, 12, 3, 2, 2, 2, 99, 100, 7, 107, 2, 2, 100, 101, 7, 104, 2, 2, 101, 14, 3, 2, 2, 2, 102, 103, 7, 103, 2, 2, 103, 104, 7, 110, 2, 2, 104, 105, 7, 117, 2, 2, 105, 106, 7, 103, 2, 2, 106, 16, 3, 2, 2, 2, 107, 108, 7, 110, 2, 2, 108, 109, 7, 103, 2, 2, 109, 110, 7, 118, 2, 2, 110, 18, 3, 2, 2, 2, 111, 112, 7, 63, 2, 2, 112, 20, 3, 2, 2, 2, 113, 114, 7, 61, 2, 2, 114, 22, 3, 2, 2, 2, 115, 116, 7, 125, 2, 2, 116, 24, 3, 2, 2, 2, 117, 118, 7, 127, 2, 2, 118, 26, 3, 2, 2, 2, 119, 120, 7, 104, 2, 2, 120, 121, 7, 112, 2, 2, 121, 28, 3, 2, 2, 2, 122, 123, 7, 47, 2, 2, 123, 124, 7, 64, 2, 2, 124, 30, 3, 2, 2, 2, 125, 126, 7, 102, 2, 2, 126, 127, 7, 103, 2, 2, 127, 128, 7, 104, 2, 2, 128, 32, 3, 2, 2, 2, 129, 130, 7, 60, 2, 2, 130, 34, 3, 2, 2, 2, 131, 132, 7, 86, 2, 2, 132, 133, 7, 103, 2, 2, 133, 134, 7, 112, 2, 2, 134, 135, 7, 117, 2, 2, 135, 136, 7, 113, 2, 2, 136, 137, 7, 116, 2, 2, 137, 36, 3, 2, 2, 2, 138, 139, 7, 97, 2, 2, 139, 38, 3, 2, 2, 2, 140, 141, 7, 120, 2, 2, 141, 142, 7, 50, 2, 2, 142, 143, 7, 48, 2, 2, 143, 144, 7, 50, 2, 2, 144, 145, 7, 48, 2, 2, 145, 146, 7, 52, 2, 2, 146, 40, 3, 2, 2, 2, 147, 149, 9, 2, 2, 2, 148, 147, 3, 2, 2, 2, 149, 150, 3, 2, 2, 2, 150, 148, 3, 2, 2, 2, 150, 151, 3, 2, 2, 2, 151, 152, 3, 2, 2, 2, 152, 153, 8, 21, 2, 2, 153, 42, 3, 2, 2, 2, 154, 155, 7, 49, 2, 2, 155, 156, 7, 49, 2, 2, 156, 160, 3, 2, 2, 2, 157, 159, 11, 2, 2, 2, 158, 157, 3, 2, 2, 2, 159, 162, 3, 2, 2, 2, 160, 161, 3, 2, 2, 2, 160, 158, 3, 2, 2, 2, 161, 163, 3, 2, 2, 2, 162, 160, 3, 2, 2, 2, 163, 164, 7, 12, 2, 2, 164, 165, 3, 2, 2, 2, 165, 166, 8, 22, 2, 2, 166, 44, 3, 2, 2, 2, 167, 168, 7, 49, 2, 2, 168, 169, 7, 44, 2, 2, 169, 173, 3, 2, 2, 2, 170, 172, 11, 2, 2, 2, 171, 170, 3, 2, 2, 2, 172, 175, 3, 2, 2, 2, 173, 174, 3, 2, 2, 2, 173, 171, 3, 2, 2, 2, 174, 176, 3, 2, 2, 2, 175, 173, 3, 2, 2, 2, 176, 177, 7, 44, 2, 2, 177, 178, 7, 49, 2, 2, 178, 179, 3, 2, 2, 2, 179, 180, 8, 23, 2, 2, 180, 46, 3, 2, 2, 2, 181, 182, 7, 44, 2, 2, 182, 48, 3, 2, 2, 2, 183, 184, 7, 49, 2, 2, 184, 50, 3, 2, 2, 2, 185, 186, 7, 45, 2, 2, 186, 52, 3, 2, 2, 2, 187, 188, 7, 47, 2, 2, 188, 54, 3, 2, 2, 2, 189, 190, 7, 62, 2, 2, 190, 56, 3, 2, 2, 2, 191, 192, 7, 64, 2, 2, 192, 58, 3, 2, 2, 2, 193, 194, 7, 62, 2, 2, 194, 195, 7, 63, 2, 2, 195, 60, 3, 2, 2, 2, 196, 197, 7, 64, 2, 2, 197, 198, 7, 63, 2, 2, 198, 62, 3, 2, 2, 2, 199, 200, 7, 63, 2, 2, 200, 201, 7, 63, 2, 2, 201, 64, 3, 2, 2, 2, 202, 203, 7, 35, 2, 2, 203, 204, 7, 63, 2, 2, 204, 66, 3, 2, 2, 2, 205, 206, 7, 66, 2, 2, 206, 207, 5, 83, 42, 2, 207, 68, 3, 2, 2, 2, 208, 209, 7, 39, 2, 2, 209, 210, 5, 83, 42, 2, 210, 70, 3, 2, 2, 2, 211, 212, 7, 39, 2, 2, 212, 213, 5, 79, 40, 2, 213, 72, 3, 2, 2, 2, 214, 215, 7, 111, 2, 2, 215, 216, 7, 119, 2, 2, 216, 217, 7, 118, 2, 2, 217, 74, 3, 2, 2, 2, 218, 219, 7, 86, 2, 2, 219, 220, 7, 116, 2, 2, 220, 221, 7, 119, 2, 2, 221, 228, 7, 103, 2, 2, 222, 223, 7, 72, 2, 2, 223, 224, 7, 99, 2, 2, 224, 225, 7, 110, 2, 2, 225, 226, 7, 117, 2, 2, 226, 228, 7, 103, 2, 2, 227, 218, 3, 2, 2, 2, 227, 222, 3, 2, 2, 2, 228, 76, 3, 2, 2, 2, 229, 230, 5, 79, 40, 2, 230, 231, 7, 48, 2, 2, 231, 233, 5, 79, 40, 2, 232, 234, 5, 81, 41, 2, 233, 232, 3, 2, 2, 2, 233, 234, 3, 2, 2, 2, 234, 239, 3, 2, 2, 2, 235, 236, 5, 79, 40, 2, 236, 237, 5, 81, 41, 2, 237, 239, 3, 2, 2, 2, 238, 229, 3, 2, 2, 2, 238, 235, 3, 2, 2, 2, 239, 78, 3, 2, 2, 2, 240, 242, 5, 87, 44, 2, 241, 240, 3, 2, 2, 2, 242, 243, 3, 2, 2, 2, 243, 241, 3, 2, 2, 2, 243, 244, 3, 2, 2, 2, 244, 80, 3, 2, 2, 2, 245, 247, 9, 3, 2, 2, 246, 248, 9, 4, 2, 2, 247, 246, 3, 2, 2, 2, 247, 248, 3, 2, 2, 2, 248, 249, 3, 2, 2, 2, 249, 250, 5, 79, 40, 2, 250, 82, 3, 2, 2, 2, 251, 254, 7, 97, 2, 2, 252, 254, 5, 85, 43, 2, 253, 251, 3, 2, 2, 2, 253, 252, 3, 2, 2, 2, 254, 260, 3, 2, 2, 2, 255, 259, 7, 97, 2, 2, 256, 259, 5, 85, 43, 2, 257, 259, 5, 87, 44, 2, 258, 255, 3, 2, 2, 2, 258, 256, 3, 2, 2, 2, 258, 257, 3, 2, 2, 2, 259, 262, 3, 2, 2, 2, 260, 258, 3, 2, 2, 2, 260, 261, 3, 2, 2, 2, 261, 84, 3, 2, 2, 2, 262, 260, 3, 2, 2, 2, 263, 264, 9, 5, 2, 2, 264, 86, 3, 2, 2, 2, 265, 266, 9, 6, 2, 2, 266, 88, 3, 2, 2, 2, 14, 2, 150, 160, 173, 227, 233, 238, 243, 247, 253, 258, 260, 3, 8, 2, 2]
\ No newline at end of file
T__0=1
T__1=2
T__2=3
T__3=4
T__4=5
T__5=6
T__6=7
T__7=8
T__8=9
T__9=10
T__10=11
T__11=12
T__12=13
T__13=14
T__14=15
T__15=16
T__16=17
T__17=18
SEMVER=19
WS=20
LINE_COMMENT=21
COMMENT=22
MUL=23
DIV=24
ADD=25
SUB=26
LT=27
GT=28
LE=29
GE=30
EQ=31
NE=32
GLOBAL_VAR=33
LOCAL_VAR=34
GRAPH_VAR=35
MUT=36
BOOL_LIT=37
FLOAT=38
NAT=39
CNAME=40
'('=1
')'=2
','=3
'['=4
']'=5
'if'=6
'else'=7
'let'=8
'='=9
';'=10
'{'=11
'}'=12
'fn'=13
'->'=14
'def'=15
':'=16
'Tensor'=17
'_'=18
'v0.0.2'=19
'*'=23
'/'=24
'+'=25
'-'=26
'<'=27
'>'=28
'<='=29
'>='=30
'=='=31
'!='=32
'mut'=36
......@@ -176,7 +176,7 @@ class RelayParser ( Parser ):
literalNames = [ "<INVALID>", "'('", "')'", "','", "'['", "']'", "'if'",
"'else'", "'let'", "'='", "';'", "'{'", "'}'", "'fn'",
"'->'", "'def'", "':'", "'Tensor'", "'_'", "'v0.0.2'",
"'->'", "'def'", "':'", "'Tensor'", "'_'", "'v0.0.3'",
"<INVALID>", "<INVALID>", "<INVALID>", "'*'", "'/'",
"'+'", "'-'", "'<'", "'>'", "'<='", "'>='", "'=='",
"'!='", "<INVALID>", "<INVALID>", "<INVALID>", "'mut'" ]
......
......@@ -16,7 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/
v0.0.2
v0.0.3
def @id[a](%x: a) -> a {
%x
......
......@@ -41,7 +41,7 @@ class AlphaEqualHandler:
public PatternFunctor<bool(const Pattern&, const Pattern&)> {
public:
explicit AlphaEqualHandler(bool map_free_var)
: map_free_var_(map_free_var) {}
: map_free_var_(map_free_var) { }
/*!
* Check equality of two nodes.
......@@ -60,6 +60,19 @@ class AlphaEqualHandler:
if (!rhs->derived_from<ExprNode>()) return false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
}
if (const auto lhsm = lhs.as<ModuleNode>()) {
auto rhsm = rhs.as<ModuleNode>();
if (!rhsm) return false;
if (lhsm->functions.size() != rhsm->functions.size()) return false;
for (const auto& p : lhsm->functions) {
if (!Equal(p.second, rhsm->Lookup(p.first->name_hint))) return false;
}
if (lhsm->type_definitions.size() != rhsm->type_definitions.size()) return false;
for (const auto& p : lhsm->type_definitions) {
if (!Equal(p.second, rhsm->LookupDef(p.first->var->name_hint))) return false;
}
return true;
}
return AttrEqual(lhs, rhs);
}
......@@ -70,6 +83,17 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool AttrEqual(const NodeRef& lhs, const NodeRef& rhs) {
if (&lhs == &rhs) return true;
auto lhsd = lhs.as<DictAttrsNode>();
if (lhsd) {
auto rhsd = lhs.as<DictAttrsNode>();
if (!rhsd) return false;
if (lhsd->dict.size() != rhsd->dict.size()) return false;
for (const auto& k : lhsd->dict) {
if (!Equal(k.second, rhsd->dict[k.first])) return false;
}
return true;
}
return AttrsEqualHandler::Equal(lhs, rhs);
}
/*!
......@@ -334,6 +358,7 @@ class AlphaEqualHandler:
}
// check return types.
if (!TypeEqual(lhs->ret_type, rhs->ret_type)) return false;
if (!AttrEqual(lhs->attrs, rhs->attrs)) return false;
return ExprEqual(lhs->body, rhs->body);
} else {
return false;
......@@ -490,7 +515,7 @@ class AlphaEqualHandler:
private:
// whether to map open terms.
bool map_free_var_{false};
bool map_free_var_;
// renaming of NodeRef to indicate two nodes equals to each other
std::unordered_map<NodeRef, NodeRef, NodeHash, NodeEqual> equal_map_;
};
......@@ -506,17 +531,18 @@ bool AlphaEqual(const Expr& lhs, const Expr& rhs) {
// TODO(@jroesch): move to correct namespace?
TVM_REGISTER_API("relay._make._alpha_equal")
.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
return AlphaEqualHandler(false).Equal(a, b);
});
return AlphaEqualHandler(false).Equal(a, b);
});
TVM_REGISTER_API("relay._make._type_alpha_equal")
.set_body_typed<bool(Type, Type)>([](Type a, Type b) {
return AlphaEqualHandler(false).TypeEqual(a, b);
});
return AlphaEqualHandler(false).TypeEqual(a, b);
});
TVM_REGISTER_API("relay._make._graph_equal")
.set_body_typed<bool(NodeRef, NodeRef)>([](NodeRef a, NodeRef b) {
return AlphaEqualHandler(true).Equal(a, b);
});
return AlphaEqualHandler(true).Equal(a, b);
});
} // namespace relay
} // namespace tvm
......@@ -60,6 +60,11 @@ Doc& Doc::operator<<(const std::string& right) {
return *this << Doc(right);
}
Doc& Doc::operator<<(const DocAtom& right) {
this->stream_.push_back(right);
return *this;
}
Doc Indent(int indent, const Doc& doc) {
Doc ret;
for (auto atom : doc.stream_) {
......@@ -113,5 +118,10 @@ Doc PrintString(const std::string& value) {
return doc << "\"" << value << "\"";
}
Doc PrintNewLine(int ident) {
Doc doc;
return doc << Line(ident);
}
} // namespace relay
} // namespace tvm
......@@ -60,11 +60,17 @@ class Doc {
public:
Doc() {}
explicit Doc(const std::string& str);
template<typename T>
explicit Doc(const T& str) {
(*this) << str;
}
// Append right to this.
Doc& operator<<(const Doc& right);
// Like above, but automatically lifts string to a Doc.
// Like above.
Doc& operator<<(const std::string& right);
// Like above.
Doc& operator<<(const DocAtom& right);
// Like above, but converts right to a string first.
template<typename T>
Doc& operator<<(const T& right) {
......@@ -93,6 +99,8 @@ Doc PrintBool(bool value);
Doc PrintDType(DataType dtype);
// Print a string.
Doc PrintString(const std::string& value);
// Print a newline.
Doc PrintNewLine(int indent = 0);
/*!
* \brief special method to print out const scalar
* \param dtype The data type
......@@ -106,7 +114,7 @@ Doc PrintConstScalar(DataType dtype, const T* data) {
} else if (dtype == Float(32)) {
os << data[0] << 'f';
} else if (dtype == Bool()) {
return PrintBool(data[0] != 0);
return PrintBool(data[0] != 0);
} else {
os << dtype << "(" << data[0] << ")";
}
......
......@@ -130,6 +130,8 @@ Function FunctionNode::make(tvm::Array<Var> params,
tvm::Array<TypeVar> type_params,
tvm::Attrs attrs) {
NodePtr<FunctionNode> n = make_node<FunctionNode>();
CHECK(params.defined());
CHECK(type_params.defined());
n->params = std::move(params);
n->body = std::move(body);
n->ret_type = std::move(ret_type);
......
......@@ -215,7 +215,8 @@ class PrettyPrinter :
return doc;
}
Doc PrintAttrs(const Attrs& attrs, const Expr& op);
std::vector<Doc> PrintCallAttrs(const Attrs& attrs, const Expr& op);
std::vector<Doc> PrintFuncAttrs(const Attrs& attrs);
Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) {
if (node.as_derived<ExprNode>()) {
......@@ -381,7 +382,7 @@ class PrettyPrinter :
} else {
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr << "\n";
doc_stack_.back() << temp_var << " = " << printed_expr << ";" << PrintNewLine();
return temp_var;
}
}
......@@ -445,7 +446,13 @@ class PrettyPrinter :
Doc VisitExpr_(const LetNode* op) final {
Doc doc;
doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false, true) << "\n";
doc
<< "let "
<< AllocVar(op->var)
<< " = "
<< Print(op->value, false, true)
<< ";"
<< PrintNewLine();
// we use a scope here so GNF hoisting doesn't escape too far
// and nested, unique lets are not hoisted
doc << PrintScope(op->body);
......@@ -469,8 +476,10 @@ class PrettyPrinter :
for (Var param : fn->params) {
params.push_back(AllocVar(param));
}
doc << PrintVec(params) << PrintAttrs(fn->attrs, fn);
doc << ") ";
for (const Doc& d : PrintFuncAttrs(fn->attrs)) {
params.push_back(d);
}
doc << PrintVec(params) << ") ";
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
}
......@@ -512,11 +521,14 @@ class PrettyPrinter :
// visit args first so they are lifted before the op
// this places op closer to its call site
std::vector<Doc> args;
for (Expr arg : op->args) {
for (const Expr& arg : op->args) {
args.push_back(Print(arg));
}
for (const Doc& d : PrintCallAttrs(op->attrs, op->op)) {
args.push_back(d);
}
doc << Print(op->op);
return doc << "(" << PrintVec(args) << PrintAttrs(op->attrs, op->op) << ")";
return doc << "(" << PrintVec(args) << ")";
}
Doc VisitExpr_(const RefCreateNode* op) final {
......@@ -747,40 +759,41 @@ class PrettyPrinter :
*/
class PrettyPrinter::AttrPrinter : public AttrVisitor {
public:
AttrPrinter(Doc& doc, PrettyPrinter* parent) : doc_(doc), parent_(parent) {}
AttrPrinter(std::vector<Doc>* doc, PrettyPrinter* parent) : docs(doc), parent_(parent) {}
template<typename T>
Doc PrintKV(const char* key, const T& value) {
void PrintKV(const char* key, const T& value) {
Doc doc;
return doc << ", " << key << "=" << value;
doc << key << "=" << value;
docs->push_back(doc);
}
void Visit(const char* key, double* value) final {
doc_ << PrintKV(key, value[0]);
PrintKV(key, *value);
}
void Visit(const char* key, int64_t* value) final {
doc_ << PrintKV(key, value[0]);
PrintKV(key, *value);
}
void Visit(const char* key, uint64_t* value) final {
doc_ << PrintKV(key, value[0]);
PrintKV(key, *value);
}
void Visit(const char* key, int* value) final {
doc_ << PrintKV(key, value[0]);
PrintKV(key, *value);
}
void Visit(const char* key, bool* value) final {
doc_ << PrintKV(key, PrintBool(value[0]));
PrintKV(key, PrintBool(*value));
}
void Visit(const char* key, std::string* value) final {
doc_ << PrintKV(key, PrintString(value[0]));
PrintKV(key, PrintString(*value));
}
void Visit(const char* key, void** value) final {
LOG(FATAL) << "do not allow void as argument";
}
void Visit(const char* key, DataType* value) final {
doc_ << PrintKV(key, PrintString(runtime::TVMType2String(Type2TVMType(value[0]))));
PrintKV(key, PrintString(runtime::TVMType2String(Type2TVMType(*value))));
}
void Visit(const char* key, NodeRef* value) final {
doc_ << PrintKV(key, parent_->PrintAttr(value[0]));
PrintKV(key, parent_->PrintAttr(*value));
}
void Visit(const char* key, runtime::NDArray* value) final {
LOG(FATAL) << "do not allow NDarray as argument";
......@@ -790,29 +803,45 @@ class PrettyPrinter::AttrPrinter : public AttrVisitor {
}
private:
Doc& doc_;
std::vector<Doc>* docs;
PrettyPrinter* parent_;
};
Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) {
Doc doc;
if (!attrs.defined()) return doc;
std::vector<Doc> PrettyPrinter::PrintCallAttrs(const Attrs& attrs, const Expr& op) {
std::vector<Doc> docs;
if (!attrs.defined()) return docs;
const auto* op_node = op.as<OpNode>();
if (op_node && (attrs->type_index() != op_node->attrs_type_index)) {
// fallback
return doc << ", " << meta_.GetMetaNode(attrs);
Doc doc;
doc << meta_.GetMetaNode(attrs);
docs.push_back(doc);
return docs;
} else {
AttrPrinter printer(doc, this);
AttrPrinter printer(&docs, this);
const_cast<BaseAttrsNode*>(attrs.operator->())->VisitNonDefaultAttrs(&printer);
return doc;
return docs;
}
}
std::vector<Doc> PrettyPrinter::PrintFuncAttrs(const Attrs& attrs) {
std::vector<Doc> docs;
if (!attrs.defined()) return docs;
const auto* dict_attrs = attrs.as<DictAttrsNode>();
CHECK(dict_attrs);
for (const auto& k : dict_attrs->dict) {
Doc doc;
doc << k.first << "=" << Print(k.second);
docs.push_back(doc);
}
return docs;
}
std::string PrettyPrint_(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
Doc doc;
doc << "v0.0.1" << "\n"
doc << "v0.0.3" << "\n"
<< PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str();
}
......
......@@ -23,7 +23,7 @@ from typing import Union
from functools import wraps
raises_parse_error = raises(tvm._ffi.base.TVMError)
SEMVER = "v0.0.2"
SEMVER = "v0.0.3"
BINARY_OPS = {
"*": relay.multiply,
......@@ -60,8 +60,19 @@ TYPES = {
"float16x4",
}
def assert_alpha_equal(a, b):
if not alpha_equal(a, b):
raise Exception("lhs is: ", str(a), "rhs is: ", str(b))
def roundtrip(expr):
assert_alpha_equal(relay.fromtext(str(expr)), expr)
def parse_text(code):
return relay.fromtext(SEMVER + "\n" + code)
x = relay.fromtext(SEMVER + "\n" + code)
roundtrip(x)
return x
def parses_as(code, expr):
# type: (str, relay.Expr) -> bool
......@@ -114,20 +125,20 @@ def test_int_literal():
def test_float_literal():
assert get_scalar(parse_text("1.0")) == 1.0
assert isclose(get_scalar(parse_text("1.56667")), 1.56667)
assert get_scalar(parse_text("0.0")) == 0.0
assert get_scalar(parse_text("-10.0")) == -10.0
assert get_scalar(parse_text("1.0f")) == 1.0
assert isclose(get_scalar(parse_text("1.56667f")), 1.56667)
assert get_scalar(parse_text("0.0f")) == 0.0
assert get_scalar(parse_text("-10.0f")) == -10.0
# scientific notation
assert isclose(get_scalar(parse_text("1e-1")), 1e-1)
assert get_scalar(parse_text("1e+1")) == 1e+1
assert isclose(get_scalar(parse_text("1E-1")), 1E-1)
assert get_scalar(parse_text("1E+1")) == 1E+1
assert isclose(get_scalar(parse_text("1.0e-1")), 1.0e-1)
assert get_scalar(parse_text("1.0e+1")) == 1.0e+1
assert isclose(get_scalar(parse_text("1.0E-1")), 1.0E-1)
assert get_scalar(parse_text("1.0E+1")) == 1.0E+1
assert isclose(get_scalar(parse_text("1e-1f")), 1e-1)
assert get_scalar(parse_text("1e+1f")) == 1e+1
assert isclose(get_scalar(parse_text("1E-1f")), 1E-1)
assert get_scalar(parse_text("1E+1f")) == 1E+1
assert isclose(get_scalar(parse_text("1.0e-1f")), 1.0e-1)
assert get_scalar(parse_text("1.0e+1f")) == 1.0e+1
assert isclose(get_scalar(parse_text("1.0E-1f")), 1.0E-1)
assert get_scalar(parse_text("1.0E+1f")) == 1.0E+1
def test_bool_literal():
......@@ -163,7 +174,7 @@ def test_op_assoc():
def test_vars():
# temp vars won't work b/c they start with a digit
# # temp var
# temp_var = relay.fromtext("%1")
# temp_var = parse_text("%1")
# assert isinstance(temp_var, relay.Var)
# assert temp_var.name == "1"
......@@ -321,8 +332,7 @@ def test_func():
# TODO(@jmp): Crashes if %x isn't annnotated.
def test_defn():
id_defn = relay.fromtext(
SEMVER+
id_defn = parse_text(
"""
def @id(%x: int32) -> int32 {
%x
......@@ -332,8 +342,7 @@ def test_defn():
def test_recursive_call():
id_defn = relay.fromtext(
SEMVER+
id_defn = parse_text(
"""
def @id(%x: int32) -> int32 {
@id(%x)
......@@ -361,8 +370,7 @@ def test_ifelse():
@raises_parse_error
def test_ifelse_scope():
relay.fromtext(
SEMVER+
parse_text(
"""
if (True) {
let %x = ();
......@@ -616,3 +624,27 @@ def test_tuple_type():
UNIT
)
)
if __name__ == "__main__":
test_comments()
test_int_literal()
test_float_literal()
test_bool_literal()
test_negative()
test_bin_op()
test_parens()
test_op_assoc()
test_let()
test_seq()
test_graph()
test_tuple()
test_func()
test_defn()
test_recursive_call()
test_ifelse()
test_call()
test_incomplete_type()
test_builtin_types()
test_tensor_type()
test_function_type()
test_tuple_type()
......@@ -21,7 +21,7 @@ from tvm import relay
do_print = [False]
SEMVER = "v0.0.1\n"
SEMVER = "v0.0.3\n"
def show(text):
if do_print[0]:
......@@ -175,23 +175,23 @@ def test_call_node_order():
assert relay.Call(relay.Function([x], x), [relay.Call(relay.Function([y], y), [relay.const(1)])]).astext() == SEMVER + \
("%0 = fn (%y) {\n"
" %y\n"
"}\n"
"%1 = %0(1)\n"
"};\n"
"%1 = %0(1);\n"
"%2 = fn (%x) {\n"
" %x\n"
"}\n"
"};\n"
"%2(%1)")
def test_let_inlining():
tup = relay.Tuple([relay.const(0), relay.const(0)])
x = relay.var("x")
assert relay.Let(x, tup, tup).astext() == SEMVER + \
("%0 = (0, 0)\n"
"let %x = %0\n"
("%0 = (0, 0);\n"
"let %x = %0;\n"
"%0")
assert relay.Let(x, tup, x).astext() == SEMVER + \
("let %x = (0, 0)\n"
("let %x = (0, 0);\n"
"%x")
if __name__ == "__main__":
......
......@@ -160,7 +160,6 @@ def test_type_relation_alpha_equal():
broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
identity = tvm.get_env_func("tvm.relay.type_relation.Identity")
# attrs are also compared only by pointer equality
attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
......@@ -391,7 +390,6 @@ def test_call_alpha_equal():
v1 = relay.Var("v1")
v2 = relay.Var("v2")
# attrs are compared only by pointer equality
attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment