Commit d274e4b3 by Jared Roesch Committed by Tianqi Chen

[Relay][Parser] Improve Relay parser and pretty printing, including CMAKE (#2377)

parent d0f83664
if(USE_ANTLR) if(USE_ANTLR)
if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar) file(GLOB_RECURSE ANTLR4
set(ANTLR4 "/usr/local/lib/antlr-4.7.1-complete.jar") /usr/local/lib/antlr-*-complete.jar
/usr/local/Cellar/*antlr-*-complete.jar)
if(DEFINED ANTLR4)
# Get the first element of the list of antlr jars.
# Sort and reverse the list so the item selected is the highest
# version in lib or else in Cellar if no lib installation exists.
list(SORT ANTLR4)
list(REVERSE ANTLR4)
list(GET ANTLR4 0 ANTLR4)
set(RELAY_PARSER_DIR set(RELAY_PARSER_DIR
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar) ${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar)
...@@ -14,15 +22,21 @@ if(USE_ANTLR) ...@@ -14,15 +22,21 @@ if(USE_ANTLR)
${RELAY_PARSER_DIR}/py3/RelayParser.py ${RELAY_PARSER_DIR}/py3/RelayParser.py
${RELAY_PARSER_DIR}/py3/RelayLexer.py) ${RELAY_PARSER_DIR}/py3/RelayLexer.py)
set(JAVA_HOME $ENV{JAVA_HOME})
if (NOT DEFINED JAVA_HOME)
# Hack to get system to search for Java itself.
set(JAVA_HOME "/usr")
endif()
# Generate ANTLR grammar for parsing. # Generate ANTLR grammar for parsing.
add_custom_command(OUTPUT ${RELAY_PARSER} add_custom_command(OUTPUT ${RELAY_PARSER}
COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2 COMMAND ${JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2
COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3 COMMAND ${JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3
DEPENDS ${RELAY_PARSER_DIR}/Relay.g4 DEPENDS ${RELAY_PARSER_DIR}/Relay.g4
WORKING_DIRECTORY ${RELAY_PARSER_DIR}) WORKING_DIRECTORY ${RELAY_PARSER_DIR})
add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER}) add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER})
else() else()
message(FATAL_ERROR "Can't find ANTLR4!") message(FATAL_ERROR "Can't find ANTLR4: ANTLR4=" ${ANTLR4})
endif() endif()
endif(USE_ANTLR) endif(USE_ANTLR)
...@@ -108,7 +108,9 @@ class SourceName : public NodeRef { ...@@ -108,7 +108,9 @@ class SourceName : public NodeRef {
* \brief access the internal node container * \brief access the internal node container
* \return the pointer to the internal node container * \return the pointer to the internal node container
*/ */
inline const SourceNameNode* operator->() const; inline const SourceNameNode* operator->() const {
return static_cast<SourceNameNode*>(this->node_.get());
}
/*! /*!
* \brief Get an SourceName for a given operator name. * \brief Get an SourceName for a given operator name.
......
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay._base", __name__)
...@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs ...@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from .._ffi.node import NodeBase, register_node as _register_tvm_node from .._ffi.node import NodeBase, register_node as _register_tvm_node
from . import _make from . import _make
from . import _expr from . import _expr
from . import _base
NodeBase = NodeBase NodeBase = NodeBase
...@@ -63,6 +64,9 @@ class RelayNode(NodeBase): ...@@ -63,6 +64,9 @@ class RelayNode(NodeBase):
""" """
return _expr.RelayPrint(self, show_meta_data, annotate) return _expr.RelayPrint(self, show_meta_data, annotate)
def set_span(self, span):
_base.set_span(self, span)
@register_relay_node @register_relay_node
class Span(RelayNode): class Span(RelayNode):
...@@ -71,6 +75,12 @@ class Span(RelayNode): ...@@ -71,6 +75,12 @@ class Span(RelayNode):
def __init__(self, source, lineno, col_offset): def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset) self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
@register_relay_node
class SourceName(RelayNode):
"""A identifier for a source location"""
def __init__(self, name):
self.__init_handle_by_constructor__(_make.SourceName, name)
@register_relay_node @register_relay_node
class Id(NodeBase): class Id(NodeBase):
......
grammar Relay; grammar Relay;
SEMVER: 'v0.0.1' ;
// Lexing // Lexing
// comments // comments
WS : [ \t\n\r]+ -> skip ; WS : [ \t\n\r]+ -> skip ;
...@@ -20,7 +22,8 @@ NE: '!=' ; ...@@ -20,7 +22,8 @@ NE: '!=' ;
opIdent: CNAME ; opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ; GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME ; LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
MUT: 'mut' ; MUT: 'mut' ;
...@@ -31,13 +34,13 @@ BOOL_LIT ...@@ -31,13 +34,13 @@ BOOL_LIT
// non-negative floats // non-negative floats
FLOAT FLOAT
: INT '.' INT EXP? // 1.35, 1.35E-9, 0.3, 4.5 : NAT '.' NAT EXP? // 1.35, 1.35E-9, 0.3, 4.5
| INT EXP // 1e10 3e4 | NAT EXP // 1e10 3e4
; ;
// non-negative ints // non-negative ints
INT: DIGIT+ ; NAT: DIGIT+ ;
fragment EXP: [eE] [+\-]? INT ; // \- since - means "range" inside [...] fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...]
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ; CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ;
fragment LETTER: [a-zA-Z] ; fragment LETTER: [a-zA-Z] ;
...@@ -46,7 +49,7 @@ fragment DIGIT: [0-9] ; ...@@ -46,7 +49,7 @@ fragment DIGIT: [0-9] ;
// Parsing // Parsing
// A Relay program is a list of global definitions or an expression. // A Relay program is a list of global definitions or an expression.
prog: (defn* | expr) EOF ; prog: SEMVER (defn* | expr) EOF ;
// option: 'set' ident BOOL_LIT ; // option: 'set' ident BOOL_LIT ;
...@@ -73,10 +76,11 @@ expr ...@@ -73,10 +76,11 @@ expr
| 'if' '(' expr ')' body 'else' body # ifElse | 'if' '(' expr ')' body 'else' body # ifElse
// sequencing // sequencing
| 'let' MUT? var '=' expr ';' expr # seq | 'let' MUT? var '=' expr ';' expr # let
| 'let' MUT? var '=' '{' expr '}' ';' expr # seq | 'let' MUT? var '=' '{' expr '}' ';' expr # let
// sugar for let %_ = expr; expr // sugar for let %_ = expr; expr
| expr ';' expr # seq | expr ';' expr # let
| ident '=' expr ';' expr # graph
// mutable update // mutable update
// | ident '=' expr # writeRef // | ident '=' expr # writeRef
...@@ -84,16 +88,25 @@ expr ...@@ -84,16 +88,25 @@ expr
| ident # identExpr | ident # identExpr
| scalar # scalarExpr | scalar # scalarExpr
// | expr '.' INT # project // | expr '.' NAT # project
// | 'debug' # debug // | 'debug' # debug
; ;
func: 'fn' varList ('->' type_)? body ; func: 'fn' '(' argList ')' ('->' type_)? body ;
defn: 'def' ident varList ('->' type_)? body ; defn: 'def' ident '(' argList ')' ('->' type_)? body ;
argList
: varList
| attrList
| varList ',' attrList
;
varList: '(' (var (',' var)*)? ')' ; varList: (var (',' var)*)? ;
var: ident (':' type_)? ; var: ident (':' type_)? ;
attrList: (attr (',' attr)*)? ;
attr: CNAME '=' expr ;
// TODO(@jmp): for improved type annotations // TODO(@jmp): for improved type annotations
// returnAnno: (ident ':')? type_ ; // returnAnno: (ident ':')? type_ ;
...@@ -110,7 +123,7 @@ type_ ...@@ -110,7 +123,7 @@ type_
// | identType '[' (type_ (',' type_)*)? ']' # callType // | identType '[' (type_ (',' type_)*)? ']' # callType
| 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType | 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '_' # incompleteType | '_' # incompleteType
| INT # intType | NAT # intType
; ;
shapeSeq shapeSeq
...@@ -123,20 +136,20 @@ shape ...@@ -123,20 +136,20 @@ shape
: '(' shape ')' # parensShape : '(' shape ')' # parensShape
// | type_ op=('*'|'/') type_ # binOpType // | type_ op=('*'|'/') type_ # binOpType
// | type_ op=('+'|'-') type_ # binOpType // | type_ op=('+'|'-') type_ # binOpType
| INT # intShape | NAT # intShape
; ;
identType: CNAME ; identType: CNAME ;
// Int8, Int16, Int32, Int64 // int8, int16, int32, int64
// UInt8, UInt16, UInt32, UInt64 // uint8, uint16, uint32, uint64
// Float16, Float32, Float64 // float16, float32, float64
// Bool // bool
body: '{' expr '}' ; body: '{' expr '}' ;
scalar scalar
: FLOAT # scalarFloat : FLOAT # scalarFloat
| INT # scalarInt | NAT # scalarInt
| BOOL_LIT # scalarBool | BOOL_LIT # scalarBool
; ;
...@@ -144,4 +157,5 @@ ident ...@@ -144,4 +157,5 @@ ident
: opIdent : opIdent
| GLOBAL_VAR | GLOBAL_VAR
| LOCAL_VAR | LOCAL_VAR
| GRAPH_VAR
; ;
"""A parser for Relay's text format.""" """A parser for Relay's text format."""
from __future__ import absolute_import from __future__ import absolute_import
from .. import register_func
def enabled(): def enabled():
"""Is the parser enabled/Can we import the parser?""" """Checks whether the parser is enabled, this allows users to
optionally support building the parser.
We use this check before importing the parser.
"""
try: try:
# pylint: disable=unused-variable # pylint: disable=unused-variable
from tvm.relay import _parser from tvm.relay import _parser
...@@ -11,7 +16,8 @@ def enabled(): ...@@ -11,7 +16,8 @@ def enabled():
except Exception: except Exception:
return False return False
def fromtext(data): @register_func("relay.fromtext")
def fromtext(data, source_name=None):
"""Parse a Relay program.""" """Parse a Relay program."""
from tvm.relay import _parser from tvm.relay import _parser
return _parser.fromtext(data) return _parser.fromtext(data, source_name)
...@@ -32,6 +32,11 @@ SourceName SourceName::Get(const std::string& name) { ...@@ -32,6 +32,11 @@ SourceName SourceName::Get(const std::string& name) {
return SourceName(GetSourceNameNode(name)); return SourceName(GetSourceNameNode(name));
} }
TVM_REGISTER_API("relay._make.SourceName")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = SourceName::Get(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) { .set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) {
p->stream << "SourceName(" << node->name << ", " << node << ")"; p->stream << "SourceName(" << node->name << ", " << node << ")";
...@@ -66,5 +71,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) ...@@ -66,5 +71,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(IdNode); TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_API("relay._base.set_span")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef node_ref = args[0];
auto rn = node_ref.as_derived<RelayNode>();
CHECK(rn);
Span sp = args[1];
rn->span = sp;
});
} // namespace relay } // namespace relay
} // namespace tvm } // namespace tvm
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment