Commit d274e4b3 by Jared Roesch Committed by Tianqi Chen

[Relay][Parser] Improve Relay parser and pretty printing, including CMAKE (#2377)

parent d0f83664
if(USE_ANTLR)
if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar)
set(ANTLR4 "/usr/local/lib/antlr-4.7.1-complete.jar")
file(GLOB_RECURSE ANTLR4
/usr/local/lib/antlr-*-complete.jar
/usr/local/Cellar/*antlr-*-complete.jar)
if(DEFINED ANTLR4)
# Get the first element of the list of antlr jars.
# Sort and reverse the list so the item selected is the highest
# version in lib or else in Cellar if no lib installation exists.
list(SORT ANTLR4)
list(REVERSE ANTLR4)
list(GET ANTLR4 0 ANTLR4)
set(RELAY_PARSER_DIR
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar)
......@@ -14,15 +22,21 @@ if(USE_ANTLR)
${RELAY_PARSER_DIR}/py3/RelayParser.py
${RELAY_PARSER_DIR}/py3/RelayLexer.py)
set(JAVA_HOME $ENV{JAVA_HOME})
if (NOT DEFINED JAVA_HOME)
# Hack to get system to search for Java itself.
set(JAVA_HOME "/usr")
endif()
# Generate ANTLR grammar for parsing.
add_custom_command(OUTPUT ${RELAY_PARSER}
COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2
COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3
COMMAND ${JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2
COMMAND ${JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3
DEPENDS ${RELAY_PARSER_DIR}/Relay.g4
WORKING_DIRECTORY ${RELAY_PARSER_DIR})
add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER})
else()
message(FATAL_ERROR "Can't find ANTLR4!")
message(FATAL_ERROR "Can't find ANTLR4: ANTLR4=" ${ANTLR4})
endif()
endif(USE_ANTLR)
......@@ -108,7 +108,9 @@ class SourceName : public NodeRef {
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const SourceNameNode* operator->() const;
inline const SourceNameNode* operator->() const {
return static_cast<SourceNameNode*>(this->node_.get());
}
/*!
* \brief Get an SourceName for a given operator name.
......
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay._base", __name__)
......@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from .._ffi.node import NodeBase, register_node as _register_tvm_node
from . import _make
from . import _expr
from . import _base
NodeBase = NodeBase
......@@ -63,6 +64,9 @@ class RelayNode(NodeBase):
"""
return _expr.RelayPrint(self, show_meta_data, annotate)
def set_span(self, span):
_base.set_span(self, span)
@register_relay_node
class Span(RelayNode):
......@@ -71,6 +75,12 @@ class Span(RelayNode):
def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
@register_relay_node
class SourceName(RelayNode):
"""A identifier for a source location"""
def __init__(self, name):
self.__init_handle_by_constructor__(_make.SourceName, name)
@register_relay_node
class Id(NodeBase):
......
grammar Relay;
SEMVER: 'v0.0.1' ;
// Lexing
// comments
WS : [ \t\n\r]+ -> skip ;
......@@ -20,7 +22,8 @@ NE: '!=' ;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
MUT: 'mut' ;
......@@ -31,13 +34,13 @@ BOOL_LIT
// non-negative floats
FLOAT
: INT '.' INT EXP? // 1.35, 1.35E-9, 0.3, 4.5
| INT EXP // 1e10 3e4
: NAT '.' NAT EXP? // 1.35, 1.35E-9, 0.3, 4.5
| NAT EXP // 1e10 3e4
;
// non-negative ints
INT: DIGIT+ ;
fragment EXP: [eE] [+\-]? INT ; // \- since - means "range" inside [...]
NAT: DIGIT+ ;
fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...]
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ;
fragment LETTER: [a-zA-Z] ;
......@@ -46,7 +49,7 @@ fragment DIGIT: [0-9] ;
// Parsing
// A Relay program is a list of global definitions or an expression.
prog: (defn* | expr) EOF ;
prog: SEMVER (defn* | expr) EOF ;
// option: 'set' ident BOOL_LIT ;
......@@ -73,10 +76,11 @@ expr
| 'if' '(' expr ')' body 'else' body # ifElse
// sequencing
| 'let' MUT? var '=' expr ';' expr # seq
| 'let' MUT? var '=' '{' expr '}' ';' expr # seq
| 'let' MUT? var '=' expr ';' expr # let
| 'let' MUT? var '=' '{' expr '}' ';' expr # let
// sugar for let %_ = expr; expr
| expr ';' expr # seq
| expr ';' expr # let
| ident '=' expr ';' expr # graph
// mutable update
// | ident '=' expr # writeRef
......@@ -84,16 +88,25 @@ expr
| ident # identExpr
| scalar # scalarExpr
// | expr '.' INT # project
// | 'debug' # debug
// | expr '.' NAT # project
// | 'debug' # debug
;
func: 'fn' varList ('->' type_)? body ;
defn: 'def' ident varList ('->' type_)? body ;
func: 'fn' '(' argList ')' ('->' type_)? body ;
defn: 'def' ident '(' argList ')' ('->' type_)? body ;
argList
: varList
| attrList
| varList ',' attrList
;
varList: '(' (var (',' var)*)? ')' ;
varList: (var (',' var)*)? ;
var: ident (':' type_)? ;
attrList: (attr (',' attr)*)? ;
attr: CNAME '=' expr ;
// TODO(@jmp): for improved type annotations
// returnAnno: (ident ':')? type_ ;
......@@ -110,7 +123,7 @@ type_
// | identType '[' (type_ (',' type_)*)? ']' # callType
| 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '_' # incompleteType
| INT # intType
| NAT # intType
;
shapeSeq
......@@ -123,20 +136,20 @@ shape
: '(' shape ')' # parensShape
// | type_ op=('*'|'/') type_ # binOpType
// | type_ op=('+'|'-') type_ # binOpType
| INT # intShape
| NAT # intShape
;
identType: CNAME ;
// Int8, Int16, Int32, Int64
// UInt8, UInt16, UInt32, UInt64
// Float16, Float32, Float64
// Bool
// int8, int16, int32, int64
// uint8, uint16, uint32, uint64
// float16, float32, float64
// bool
body: '{' expr '}' ;
scalar
: FLOAT # scalarFloat
| INT # scalarInt
| NAT # scalarInt
| BOOL_LIT # scalarBool
;
......@@ -144,4 +157,5 @@ ident
: opIdent
| GLOBAL_VAR
| LOCAL_VAR
| GRAPH_VAR
;
"""A parser for Relay's text format."""
from __future__ import absolute_import
from .. import register_func
def enabled():
"""Is the parser enabled/Can we import the parser?"""
"""Checks whether the parser is enabled, this allows users to
optionally support building the parser.
We use this check before importing the parser.
"""
try:
# pylint: disable=unused-variable
from tvm.relay import _parser
......@@ -11,7 +16,8 @@ def enabled():
except Exception:
return False
def fromtext(data):
@register_func("relay.fromtext")
def fromtext(data, source_name=None):
"""Parse a Relay program."""
from tvm.relay import _parser
return _parser.fromtext(data)
return _parser.fromtext(data, source_name)
......@@ -32,6 +32,11 @@ SourceName SourceName::Get(const std::string& name) {
return SourceName(GetSourceNameNode(name));
}
TVM_REGISTER_API("relay._make.SourceName")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = SourceName::Get(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) {
p->stream << "SourceName(" << node->name << ", " << node << ")";
......@@ -66,5 +71,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_API("relay._base.set_span")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef node_ref = args[0];
auto rn = node_ref.as_derived<RelayNode>();
CHECK(rn);
Span sp = args[1];
rn->span = sp;
});
} // namespace relay
} // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment