Commit d3bc59d2 by Josh Pollock Committed by Tianqi Chen

[Relay][RFC] Relay IR Text Format (#1781)

parent eb6d64f1
...@@ -209,3 +209,7 @@ tvm_t.* ...@@ -209,3 +209,7 @@ tvm_t.*
# patch sentinel # patch sentinel
patched.txt patched.txt
# Python type checking
.mypy_cache/
.pyre/
...@@ -47,6 +47,7 @@ tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF) ...@@ -47,6 +47,7 @@ tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
tvm_option(USE_SORT "Build with sort support" OFF) tvm_option(USE_SORT "Build with sort support" OFF)
tvm_option(USE_NNPACK "Build with nnpack support" OFF) tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_RANDOM "Build with random support" OFF) tvm_option(USE_RANDOM "Build with random support" OFF)
tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
# include directories # include directories
include_directories("include") include_directories("include")
...@@ -183,6 +184,7 @@ include(cmake/modules/Metal.cmake) ...@@ -183,6 +184,7 @@ include(cmake/modules/Metal.cmake)
include(cmake/modules/ROCM.cmake) include(cmake/modules/ROCM.cmake)
include(cmake/modules/SGX.cmake) include(cmake/modules/SGX.cmake)
include(cmake/modules/LLVM.cmake) include(cmake/modules/LLVM.cmake)
include(cmake/modules/ANTLR.cmake)
include(cmake/modules/contrib/BLAS.cmake) include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/Sort.cmake) include(cmake/modules/contrib/Sort.cmake)
......
...@@ -98,6 +98,7 @@ stage('Build') { ...@@ -98,6 +98,7 @@ stage('Build') {
echo set\\(USE_GRAPH_RUNTIME ON\\) >> config.cmake echo set\\(USE_GRAPH_RUNTIME ON\\) >> config.cmake
echo set\\(USE_STACKVM_RUNTIME ON\\) >> config.cmake echo set\\(USE_STACKVM_RUNTIME ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake
echo set\\(USE_ANTLR ON\\) >> config.cmake
echo set\\(USE_BLAS openblas\\) >> config.cmake echo set\\(USE_BLAS openblas\\) >> config.cmake
echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake
echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake
...@@ -133,6 +134,7 @@ stage('Build') { ...@@ -133,6 +134,7 @@ stage('Build') {
echo set\\(USE_LLVM llvm-config-4.0\\) >> config.cmake echo set\\(USE_LLVM llvm-config-4.0\\) >> config.cmake
echo set\\(USE_NNPACK ON\\) >> config.cmake echo set\\(USE_NNPACK ON\\) >> config.cmake
echo set\\(NNPACK_PATH /NNPACK/build/\\) >> config.cmake echo set\\(NNPACK_PATH /NNPACK/build/\\) >> config.cmake
echo set\\(USE_ANTLR ON\\) >> config.cmake
echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake
echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake
""" """
......
...@@ -128,3 +128,6 @@ set(USE_ROCBLAS OFF) ...@@ -128,3 +128,6 @@ set(USE_ROCBLAS OFF)
# Whether use contrib sort # Whether use contrib sort
set(USE_SORT OFF) set(USE_SORT OFF)
# Build ANTLR parser for Relay text format
set(USE_ANTLR OFF)
if(USE_ANTLR)
if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar)
set(ANTLR4 "/usr/local/lib/antlr-4.7.1-complete.jar")
set(RELAY_PARSER_DIR
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar)
set(RELAY_PARSER
${RELAY_PARSER_DIR}/py2/RelayVisitor.py
${RELAY_PARSER_DIR}/py2/RelayParser.py
${RELAY_PARSER_DIR}/py2/RelayLexer.py
${RELAY_PARSER_DIR}/py3/RelayVisitor.py
${RELAY_PARSER_DIR}/py3/RelayParser.py
${RELAY_PARSER_DIR}/py3/RelayLexer.py)
# Generate ANTLR grammar for parsing.
add_custom_command(OUTPUT ${RELAY_PARSER}
COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2
COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3
DEPENDS ${RELAY_PARSER_DIR}/Relay.g4
WORKING_DIRECTORY ${RELAY_PARSER_DIR})
add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER})
else()
message(FATAL_ERROR "Can't find ANTLR4!")
endif()
endif(USE_ANTLR)
...@@ -40,10 +40,3 @@ COPY install/ubuntu_install_nnpack.sh /install/ubuntu_install_nnpack.sh ...@@ -40,10 +40,3 @@ COPY install/ubuntu_install_nnpack.sh /install/ubuntu_install_nnpack.sh
RUN bash /install/ubuntu_install_nnpack.sh RUN bash /install/ubuntu_install_nnpack.sh
ENV PATH $PATH:$CARGO_HOME/bin:/usr/lib/go-1.10/bin ENV PATH $PATH:$CARGO_HOME/bin:/usr/lib/go-1.10/bin
# ANTLR deps
COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh
RUN bash /install/ubuntu_install_java.sh
COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh
RUN bash /install/ubuntu_install_antlr.sh
cd /usr/local/lib cd /usr/local/lib
wget https://www.antlr.org/download/antlr-4.7.1-complete.jar wget https://www.antlr.org/download/antlr-4.7.1-complete.jar
cd - cd -
alias antlr4='java -jar /usr/local/lib/antlr-4.7.1-complete.jar'
...@@ -8,6 +8,7 @@ from . import expr ...@@ -8,6 +8,7 @@ from . import expr
from . import module from . import module
from . import ir_pass from . import ir_pass
from .build_module import build, build_config, create_executor from .build_module import build, build_config, create_executor
from . import parser
# Root operators # Root operators
from .op import Op from .op import Op
...@@ -52,7 +53,6 @@ Let = expr.Let ...@@ -52,7 +53,6 @@ Let = expr.Let
If = expr.If If = expr.If
TupleGetItem = expr.TupleGetItem TupleGetItem = expr.TupleGetItem
# helper functions # helper functions
var = expr.var var = expr.var
const = expr.const const = expr.const
...@@ -63,3 +63,6 @@ bind = expr.bind ...@@ -63,3 +63,6 @@ bind = expr.bind
def _debug(*args): def _debug(*args):
import pdb import pdb
pdb.set_trace() pdb.set_trace()
# Parser
fromtext = parser.fromtext
...@@ -22,7 +22,7 @@ class Constant(Expr): ...@@ -22,7 +22,7 @@ class Constant(Expr):
class Tuple(Expr): class Tuple(Expr):
fields = .. # type: List[Expr] fields = ... # type: List[Expr]
def __init__(self, fields): def __init__(self, fields):
# type: (List[Expr]) -> None # type: (List[Expr]) -> None
...@@ -77,10 +77,10 @@ class Call(Expr): ...@@ -77,10 +77,10 @@ class Call(Expr):
"""A function call in Relay, see tvm/relay/expr.h for more details.""" """A function call in Relay, see tvm/relay/expr.h for more details."""
op = ... # type: Expr op = ... # type: Expr
args = ... # type: List[Expr] args = ... # type: List[Expr]
# todo(@jroesch): add attrs # todo(@jroesch): add attrs. revise attrs type in __init__
def __init__(self, op, args, attrs, ty_args=None): def __init__(self, op, args, attrs=None, ty_args=None):
# type: (Expr, List[Expr], Optional[List[Type]]) -> None # type: (Expr, List[Expr], Optional[List[Any]], Optional[List[Type]]) -> None
if not ty_args: if not ty_args:
ty_args = [] ty_args = []
......
grammar Relay;
// Lexing
// comments
WS : [ \t\n\r]+ -> skip ;
LINE_COMMENT : '//' .*? '\n' -> skip ;
COMMENT : '/*' .*? '*/' -> skip ;
// operators
MUL: '*' ;
DIV: '/' ;
ADD: '+' ;
SUB: '-' ;
LT: '<' ;
GT: '>' ;
LE: '<=' ;
GE: '>=' ;
EQ: '==' ;
NE: '!=' ;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME ;
MUT: 'mut' ;
BOOL_LIT
: 'True'
| 'False'
;
// non-negative floats
FLOAT
: INT '.' INT EXP? // 1.35, 1.35E-9, 0.3, 4.5
| INT EXP // 1e10 3e4
;
// non-negative ints
INT: DIGIT+ ;
fragment EXP: [eE] [+\-]? INT ; // \- since - means "range" inside [...]
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ;
fragment LETTER: [a-zA-Z] ;
fragment DIGIT: [0-9] ;
// Parsing
// A Relay program is a list of global definitions or an expression.
prog: (defn* | expr) EOF ;
// option: 'set' ident BOOL_LIT ;
expr
// operators
: '(' expr ')' # parens
| '-' expr # neg
| expr op=('*'|'/') expr # binOp
| expr op=('+'|'-') expr # binOp
| expr op=('<'|'>'|'<='|'>=') expr # binOp
| expr op=('=='|'!=') expr # binOp
// function definition and application
| expr '(' (expr (',' expr)*)? ')' # call
| func # funcExpr
// tuples and tensors
| '(' ')' # tuple
| '(' expr ',' ')' # tuple
| '(' expr (',' expr)+ ')' # tuple
| '[' (expr (',' expr)*)? ']' # tensor
| 'if' '(' expr ')' body 'else' body # ifElse
// sequencing
| 'let' MUT? var '=' expr ';' expr # seq
| 'let' MUT? var '=' '{' expr '}' ';' expr # seq
// sugar for let %_ = expr; expr
| expr ';' expr # seq
// mutable update
// | ident '=' expr # writeRef
// | expr '^' # readRef
| ident # identExpr
| scalar # scalarExpr
// | expr '.' INT # project
// | 'debug' # debug
;
func: 'fn' varList ('->' type_)? body ;
defn: 'def' ident varList ('->' type_)? body ;
varList: '(' (var (',' var)*)? ')' ;
var: ident (':' type_)? ;
// TODO(@jmp): for improved type annotations
// returnAnno: (ident ':')? type_ ;
// relations: 'where' relation (',' relation)* ;
// relation: ident '(' (type_ (',' type_)*)? ')' ;
type_
: '(' ')' # tupleType
| '(' type_ ',' ')' # tupleType
| '(' type_ (',' type_)+ ')' # tupleType
| identType # identTypeType
| 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType
// currently unused
// | identType '[' (type_ (',' type_)*)? ']' # callType
| 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '_' # incompleteType
| INT # intType
;
shapeSeq
: '(' ')'
| '(' shape ',' ')'
| '(' shape (',' shape)+ ')'
;
shape
: '(' shape ')' # parensShape
// | type_ op=('*'|'/') type_ # binOpType
// | type_ op=('+'|'-') type_ # binOpType
| INT # intShape
;
identType: CNAME ;
// Int8, Int16, Int32, Int64
// UInt8, UInt16, UInt32, UInt64
// Float16, Float32, Float64
// Bool
body: '{' expr '}' ;
scalar
: FLOAT # scalarFloat
| INT # scalarInt
| BOOL_LIT # scalarBool
;
ident
: opIdent
| GLOBAL_VAR
| LOCAL_VAR
;
"""A parser for Relay's text format."""
from __future__ import absolute_import
def enabled():
"""Is the parser enabled/Can we import the parser?"""
try:
# pylint: disable=unused-variable
from tvm.relay import _parser
return True
# pylint: disable=broad-except
except Exception:
return False
def fromtext(data):
"""Parse a Relay program."""
from tvm.relay import _parser
return _parser.fromtext(data)
...@@ -156,7 +156,7 @@ class FuncType(Type): ...@@ -156,7 +156,7 @@ class FuncType(Type):
class IncompleteType(Type): class IncompleteType(Type):
"""An incomplete type.""" """An incomplete type."""
def __init__(self, kind): def __init__(self, kind=Kind.Type):
self.__init_handle_by_constructor__(_make.IncompleteType, kind) self.__init_handle_by_constructor__(_make.IncompleteType, kind)
@register_relay_node @register_relay_node
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment