Commit d3bc59d2 by Josh Pollock Committed by Tianqi Chen

[Relay][RFC] Relay IR Text Format (#1781)

parent eb6d64f1
......@@ -209,3 +209,7 @@ tvm_t.*
# patch sentinel
patched.txt
# Python type checking
.mypy_cache/
.pyre/
......@@ -47,6 +47,7 @@ tvm_option(USE_ROCBLAS "Build with ROCM:RoCBLAS" OFF)
tvm_option(USE_SORT "Build with sort support" OFF)
tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_RANDOM "Build with random support" OFF)
tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
# include directories
include_directories("include")
......@@ -183,6 +184,7 @@ include(cmake/modules/Metal.cmake)
include(cmake/modules/ROCM.cmake)
include(cmake/modules/SGX.cmake)
include(cmake/modules/LLVM.cmake)
include(cmake/modules/ANTLR.cmake)
include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/Sort.cmake)
......
......@@ -98,6 +98,7 @@ stage('Build') {
echo set\\(USE_GRAPH_RUNTIME ON\\) >> config.cmake
echo set\\(USE_STACKVM_RUNTIME ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake
echo set\\(USE_ANTLR ON\\) >> config.cmake
echo set\\(USE_BLAS openblas\\) >> config.cmake
echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake
echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake
......@@ -133,6 +134,7 @@ stage('Build') {
echo set\\(USE_LLVM llvm-config-4.0\\) >> config.cmake
echo set\\(USE_NNPACK ON\\) >> config.cmake
echo set\\(NNPACK_PATH /NNPACK/build/\\) >> config.cmake
echo set\\(USE_ANTLR ON\\) >> config.cmake
echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake
echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake
"""
......
......@@ -128,3 +128,6 @@ set(USE_ROCBLAS OFF)
# Whether use contrib sort
set(USE_SORT OFF)
# Build ANTLR parser for Relay text format
set(USE_ANTLR OFF)
if(USE_ANTLR)
if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar)
set(ANTLR4 "/usr/local/lib/antlr-4.7.1-complete.jar")
set(RELAY_PARSER_DIR
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar)
set(RELAY_PARSER
${RELAY_PARSER_DIR}/py2/RelayVisitor.py
${RELAY_PARSER_DIR}/py2/RelayParser.py
${RELAY_PARSER_DIR}/py2/RelayLexer.py
${RELAY_PARSER_DIR}/py3/RelayVisitor.py
${RELAY_PARSER_DIR}/py3/RelayParser.py
${RELAY_PARSER_DIR}/py3/RelayLexer.py)
# Generate ANTLR grammar for parsing.
add_custom_command(OUTPUT ${RELAY_PARSER}
COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2
COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3
DEPENDS ${RELAY_PARSER_DIR}/Relay.g4
WORKING_DIRECTORY ${RELAY_PARSER_DIR})
add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER})
else()
message(FATAL_ERROR "Can't find ANTLR4!")
endif()
endif(USE_ANTLR)
......@@ -40,10 +40,3 @@ COPY install/ubuntu_install_nnpack.sh /install/ubuntu_install_nnpack.sh
RUN bash /install/ubuntu_install_nnpack.sh
ENV PATH $PATH:$CARGO_HOME/bin:/usr/lib/go-1.10/bin
# ANTLR deps
COPY install/ubuntu_install_java.sh /install/ubuntu_install_java.sh
RUN bash /install/ubuntu_install_java.sh
COPY install/ubuntu_install_antlr.sh /install/ubuntu_install_antlr.sh
RUN bash /install/ubuntu_install_antlr.sh
cd /usr/local/lib
wget https://www.antlr.org/download/antlr-4.7.1-complete.jar
cd -
alias antlr4='java -jar /usr/local/lib/antlr-4.7.1-complete.jar'
......@@ -8,6 +8,7 @@ from . import expr
from . import module
from . import ir_pass
from .build_module import build, build_config, create_executor
from . import parser
# Root operators
from .op import Op
......@@ -52,7 +53,6 @@ Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem
# helper functions
var = expr.var
const = expr.const
......@@ -63,3 +63,6 @@ bind = expr.bind
def _debug(*args):
import pdb
pdb.set_trace()
# Parser
fromtext = parser.fromtext
# pylint: disable=invalid-name, unused-import
"""A parser for Relay's text format."""
from __future__ import absolute_import
import sys
from collections import deque
from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any
from . import module
from . import expr
from . import ty
from . import op
class ParseError(Exception):
"""Exception type for parse errors."""
def __init__(self, message):
# type: (str) -> None
super(ParseError, self).__init__()
self.message = message
PYTHON_VERSION = sys.version_info.major
try:
if PYTHON_VERSION == 2:
from .grammar.py2.RelayVisitor import RelayVisitor
from .grammar.py2.RelayParser import RelayParser
from .grammar.py2.RelayLexer import RelayLexer
else:
from .grammar.py3.RelayVisitor import RelayVisitor
from .grammar.py3.RelayParser import RelayParser
from .grammar.py3.RelayLexer import RelayLexer
except ImportError:
raise ParseError("Couldn't find ANTLR parser. Try building with USE_ANTLR=ON.")
try:
from antlr4 import ParserRuleContext, InputStream, CommonTokenStream
from antlr4.tree.Tree import TerminalNode
except ImportError:
raise ParseError("Couldn't find ANTLR runtime." +
"Try running `pip{} install antlr4-python{}-runtime`."
.format(PYTHON_VERSION, PYTHON_VERSION))
BINARY_OPS = {
RelayParser.MUL: op.multiply,
RelayParser.DIV: op.divide,
RelayParser.ADD: op.add,
RelayParser.SUB: op.subtract,
RelayParser.LT: op.less,
RelayParser.GT: op.greater,
RelayParser.LE: op.less_equal,
RelayParser.GE: op.greater_equal,
RelayParser.EQ: op.equal,
RelayParser.NE: op.not_equal,
}
TYPE_PREFIXES = [
"int",
"uint",
"float",
"bool",
]
T = TypeVar("T")
Scope = Deque[Tuple[str, T]]
Scopes = Deque[Scope[T]]
def lookup(scopes, name):
# type: (Scopes[T], str) -> Optional[T]
"""Look up `name` in `scopes`."""
for scope in scopes:
for key, val in scope:
if key == name:
return val
return None
# TODO(@jmp): Use https://stackoverflow.com/q/13889941
# to figure out how to get ANTLR4 to be more unhappy about syntax errors
class ParseTreeToRelayIR(RelayVisitor):
"""Parse Relay text format into Relay IR."""
def __init__(self):
# type: () -> None
self.module = module.Module({}) # type: module.Module
# Adding an empty scope allows naked lets without pain.
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar]
super(ParseTreeToRelayIR, self).__init__()
def enter_var_scope(self):
# type: () -> None
"""Enter a new Var scope so it can be popped off later."""
self.var_scopes.appendleft(deque())
def exit_var_scope(self):
# type: () -> Scope[expr.Var]
"""Pop off the current Var scope and return it."""
return self.var_scopes.popleft()
def mk_var(self, name, type_):
# type: (str, ty.Type) -> expr.Var
"""Create a new Var and add it to the Var scope."""
var = expr.Var(name, type_)
self.var_scopes[0].appendleft((name, var))
return var
def enter_type_param_scope(self):
# type: () -> None
"""Enter a new TypeVar scope so it can be popped off later."""
self.type_param_scopes.appendleft(deque())
def exit_type_param_scope(self):
# type: () -> Scope[ty.TypeVar]
"""Pop off the current TypeVar scope and return it."""
return self.type_param_scopes.popleft()
def mk_typ(self, name, kind):
# (str, ty.Kind) -> ty.TypeVar
"""Create a new TypeVar and add it to the TypeVar scope."""
typ = ty.TypeVar(name, kind)
self.type_param_scopes[0].appendleft((name, typ))
return typ
def visitTerminal(self, node):
# type: (TerminalNode) -> Union[expr.Expr, int, float]
"""Visit lexer tokens that aren't ignored or visited by other functions."""
node_type = node.getSymbol().type
node_text = node.getText()
# variables
if node_type == RelayLexer.GLOBAL_VAR:
return expr.GlobalVar(node_text[1:])
elif node_type == RelayLexer.LOCAL_VAR:
name = node_text[1:]
var = lookup(self.var_scopes, name)
if var is None:
raise ParseError("Couldn't resolve `{}`.".format(name))
return var
# data types
elif node_type == RelayLexer.INT:
return int(node_text)
elif node_type == RelayLexer.FLOAT:
return float(node_text)
elif node_type == RelayLexer.BOOL_LIT:
if node_text == "True":
return True
elif node_text == "False":
return False
else:
raise ParseError("Unrecognized BOOL_LIT: `{}`".format(node_text))
else:
raise ParseError("todo: {}".format(node_text))
def visit_list(self, ctx_list):
# type: (List[ParserRuleContext]) -> List[Any]
""""Visit a list of contexts."""
return [self.visit(ctx) for ctx in ctx_list]
def getType_(self, ctx):
# type: (Optional[RelayParser.Type_Context]) -> Optional[ty.Type]
"""Return a (possibly None) Relay type."""
if ctx is None:
return None
return self.visit(ctx)
def visitProg(self, ctx):
# type: (RelayParser.ProgContext) -> Union[expr.Expr, env.Environment]
if ctx.defn():
self.visit_list(ctx.defn())
return self.module
return self.visit(ctx.expr())
# Exprs
def visitOpIdent(self, ctx):
# type: (RelayParser.OpIdentContext) -> op.Op
return op.get(ctx.CNAME().getText())
# pass through
def visitParens(self, ctx):
# type: (RelayParser.ParensContext) -> expr.Expr
return self.visit(ctx.expr())
# pass through
def visitBody(self, ctx):
# type: (RelayParser.BodyContext) -> expr.Expr
return self.visit(ctx.expr())
def visitScalarFloat(self, ctx):
# type: (RelayParser.ScalarFloatContext) -> expr.Constant
return expr.const(self.visit(ctx.FLOAT()))
def visitScalarInt(self, ctx):
# type: (RelayParser.ScalarIntContext) -> expr.Constant
return expr.const(self.visit(ctx.INT()))
def visitScalarBool(self, ctx):
# type: (RelayParser.ScalarBoolContext) -> expr.Constant
return expr.const(self.visit(ctx.BOOL_LIT()))
def visitNeg(self, ctx):
# type: (RelayParser.NegContext) -> Union[expr.Constant, expr.Call]
val = self.visit(ctx.expr())
if isinstance(val, expr.Constant) and val.data.asnumpy().ndim == 0:
# fold Neg in for scalars
return expr.const(-val.data.asnumpy().item())
return op.negative(val)
def visitTuple(self, ctx):
# type: (RelayParser.TupleContext) -> expr.Tuple
tup = self.visit_list(ctx.expr())
return expr.Tuple(tup)
# Currently doesn't support mutable sequencing.
def visitSeq(self, ctx):
# type: (RelayParser.SeqContext) -> expr.Let
"""Desugar various sequence constructs to Relay Let nodes."""
if ctx.MUT() is not None:
raise ParseError("Mutation is currently unsupported.")
if ctx.var() is None or ctx.var().ident() is None:
# anonymous identity
ident = "_"
type_ = None
else:
local_var = ctx.var().ident().LOCAL_VAR()
if local_var is None:
raise ParseError('Only local ids may be used in `let`s.')
ident = local_var.getText()[1:]
type_ = self.getType_(ctx.var().type_())
var = self.mk_var(ident, type_)
self.enter_var_scope()
value = self.visit(ctx.expr(0))
self.exit_var_scope()
body = self.visit(ctx.expr(1))
return expr.Let(var, value, body)
def visitBinOp(self, ctx):
# type: (RelayParser.BinOpContext) -> expr.Call
"""Desugar binary operators."""
arg0, arg1 = self.visit_list(ctx.expr())
relay_op = BINARY_OPS.get(ctx.op.type)
if relay_op is None:
raise ParseError("Unimplemented binary op.")
return relay_op(arg0, arg1)
def visitVar(self, ctx):
# type: (RelayParser.VarContext) -> expr.Var
ident = ctx.ident().LOCAL_VAR()
if ident is None:
raise ParseError('Only local ids may be used in params.')
type_ = self.getType_(ctx.type_())
return self.mk_var(ident.getText()[1:], type_)
def visitVarList(self, ctx):
# type: (RelayParser.VarListContext) -> List[expr.Var]
return self.visit_list(ctx.var())
def mk_func(self, ctx):
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> Function
"""Construct a function from either a Func or Defn."""
# Enter var scope early to put params in scope.
self.enter_var_scope()
# Capture type params in params.
self.enter_type_param_scope()
var_list = self.visit(ctx.varList())
ret_type = self.getType_(ctx.type_())
type_params = list(self.exit_type_param_scope())
if type_params:
_, type_params = zip(*type_params)
body = self.visit(ctx.body())
self.exit_var_scope()
return expr.Function(var_list, body, ret_type, type_params) # type: ignore
def visitFunc(self, ctx):
# type: (RelayParser.FuncContext) -> expr.Function
return self.mk_func(ctx)
def visitDefn(self, ctx):
# type: (RelayParser.DefnContext) -> None
ident = ctx.ident().GLOBAL_VAR()
if ident is None:
raise ParseError('Only global ids may be used in `def`s.')
ident = expr.GlobalVar(ident.getText()[1:])
self.module[ident] = self.mk_func(ctx)
def visitCall(self, ctx):
# type: (RelayParser.CallContext) -> expr.Call
visited_exprs = self.visit_list(ctx.expr())
func = visited_exprs[0]
args = visited_exprs[1:]
return expr.Call(func, args, None, None)
def visitIfElse(self, ctx):
# type: (RelayParser.IfElseContext) -> expr.If
"""Construct a Relay If node. Creates a new scope for each branch."""
cond = self.visit(ctx.expr())
self.enter_var_scope()
true_branch = self.visit(ctx.body(0))
self.exit_var_scope()
self.enter_var_scope()
false_branch = self.visit(ctx.body(1))
self.exit_var_scope()
return expr.If(cond, true_branch, false_branch)
# Types
# pylint: disable=unused-argument
def visitIncompleteType(self, ctx):
# type (RelayParser.IncompleteTypeContext) -> None:
return None
def visitIdentType(self, ctx):
# type: (RelayParser.IdentTypeContext) -> Union[ty.TensorType, str]
ident_type = ctx.CNAME().getText()
# look through all type prefixes for a match
for type_prefix in TYPE_PREFIXES:
if ident_type.startswith(type_prefix):
return ty.scalar_type(ident_type)
raise ParseError("Unknown builtin type: {}".format(ident_type))
# def visitCallType(self, ctx):
# # type: (RelayParser.CallTypeContext) -> Union[expr.Expr, ty.TensorType]
# ident_type = ctx.identType().CNAME().getText()
# args = self.visit_list(ctx.type_())
# if not args:
# raise ParseError("Type-level functions must have arguments!")
# func_type = TYPE_FUNCS.get(ident_type)(args)
# if func_type is None:
# raise ParseError("Unknown type-level function: `{}`".format(ident_type))
# else:
# return func_type
def visitParensShape(self, ctx):
# type: (RelayParser.ParensShapeContext) -> int
return self.visit(ctx.shape())
def visitShapeSeq(self, ctx):
# type: (RelayParser.ShapeSeqContext) -> List[int]
return self.visit_list(ctx.shape())
def visitTensorType(self, ctx):
# type: (RelayParser.TensorTypeContext) -> ty.TensorType
"""Create a simple tensor type. No generics."""
shape = self.visit(ctx.shapeSeq())
dtype = self.visit(ctx.type_())
if not isinstance(dtype, ty.TensorType):
raise ParseError("Expected dtype to be a Relay base type.")
dtype = dtype.dtype
return ty.TensorType(shape, dtype)
def visitTupleType(self, ctx):
# type: (RelayParser.TupleTypeContext) -> ty.TupleType
return ty.TupleType(self.visit_list(ctx.type_()))
def visitFuncType(self, ctx):
# type: (RelayParser.FuncTypeContext) -> ty.FuncType
types = self.visit_list(ctx.type_())
arg_types = types[:-1]
ret_type = types[-1]
return ty.FuncType(arg_types, ret_type, [], None)
def make_parser(data):
# type: (str) -> RelayParser
"""Construct a RelayParser a given data stream."""
input_stream = InputStream(data)
lexer = RelayLexer(input_stream)
token_stream = CommonTokenStream(lexer)
return RelayParser(token_stream)
def fromtext(data):
# type: (str) -> Union[expr.Expr, env.Environment]
"""Parse a Relay program."""
tree = make_parser(data).prog()
return ParseTreeToRelayIR().visit(tree)
......@@ -22,7 +22,7 @@ class Constant(Expr):
class Tuple(Expr):
fields = .. # type: List[Expr]
fields = ... # type: List[Expr]
def __init__(self, fields):
# type: (List[Expr]) -> None
......@@ -77,10 +77,10 @@ class Call(Expr):
"""A function call in Relay, see tvm/relay/expr.h for more details."""
op = ... # type: Expr
args = ... # type: List[Expr]
# todo(@jroesch): add attrs
# todo(@jroesch): add attrs. revise attrs type in __init__
def __init__(self, op, args, attrs, ty_args=None):
# type: (Expr, List[Expr], Optional[List[Type]]) -> None
def __init__(self, op, args, attrs=None, ty_args=None):
# type: (Expr, List[Expr], Optional[List[Any]], Optional[List[Type]]) -> None
if not ty_args:
ty_args = []
......
grammar Relay;
// Lexing
// comments
WS : [ \t\n\r]+ -> skip ;
LINE_COMMENT : '//' .*? '\n' -> skip ;
COMMENT : '/*' .*? '*/' -> skip ;
// operators
MUL: '*' ;
DIV: '/' ;
ADD: '+' ;
SUB: '-' ;
LT: '<' ;
GT: '>' ;
LE: '<=' ;
GE: '>=' ;
EQ: '==' ;
NE: '!=' ;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME ;
MUT: 'mut' ;
BOOL_LIT
: 'True'
| 'False'
;
// non-negative floats
FLOAT
: INT '.' INT EXP? // 1.35, 1.35E-9, 0.3, 4.5
| INT EXP // 1e10 3e4
;
// non-negative ints
INT: DIGIT+ ;
fragment EXP: [eE] [+\-]? INT ; // \- since - means "range" inside [...]
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ;
fragment LETTER: [a-zA-Z] ;
fragment DIGIT: [0-9] ;
// Parsing
// A Relay program is a list of global definitions or an expression.
prog: (defn* | expr) EOF ;
// option: 'set' ident BOOL_LIT ;
expr
// operators
: '(' expr ')' # parens
| '-' expr # neg
| expr op=('*'|'/') expr # binOp
| expr op=('+'|'-') expr # binOp
| expr op=('<'|'>'|'<='|'>=') expr # binOp
| expr op=('=='|'!=') expr # binOp
// function definition and application
| expr '(' (expr (',' expr)*)? ')' # call
| func # funcExpr
// tuples and tensors
| '(' ')' # tuple
| '(' expr ',' ')' # tuple
| '(' expr (',' expr)+ ')' # tuple
| '[' (expr (',' expr)*)? ']' # tensor
| 'if' '(' expr ')' body 'else' body # ifElse
// sequencing
| 'let' MUT? var '=' expr ';' expr # seq
| 'let' MUT? var '=' '{' expr '}' ';' expr # seq
// sugar for let %_ = expr; expr
| expr ';' expr # seq
// mutable update
// | ident '=' expr # writeRef
// | expr '^' # readRef
| ident # identExpr
| scalar # scalarExpr
// | expr '.' INT # project
// | 'debug' # debug
;
func: 'fn' varList ('->' type_)? body ;
defn: 'def' ident varList ('->' type_)? body ;
varList: '(' (var (',' var)*)? ')' ;
var: ident (':' type_)? ;
// TODO(@jmp): for improved type annotations
// returnAnno: (ident ':')? type_ ;
// relations: 'where' relation (',' relation)* ;
// relation: ident '(' (type_ (',' type_)*)? ')' ;
type_
: '(' ')' # tupleType
| '(' type_ ',' ')' # tupleType
| '(' type_ (',' type_)+ ')' # tupleType
| identType # identTypeType
| 'Tensor' '[' shapeSeq ',' type_ ']' # tensorType
// currently unused
// | identType '[' (type_ (',' type_)*)? ']' # callType
| 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '_' # incompleteType
| INT # intType
;
shapeSeq
: '(' ')'
| '(' shape ',' ')'
| '(' shape (',' shape)+ ')'
;
shape
: '(' shape ')' # parensShape
// | type_ op=('*'|'/') type_ # binOpType
// | type_ op=('+'|'-') type_ # binOpType
| INT # intShape
;
identType: CNAME ;
// Int8, Int16, Int32, Int64
// UInt8, UInt16, UInt32, UInt64
// Float16, Float32, Float64
// Bool
body: '{' expr '}' ;
scalar
: FLOAT # scalarFloat
| INT # scalarInt
| BOOL_LIT # scalarBool
;
ident
: opIdent
| GLOBAL_VAR
| LOCAL_VAR
;
"""A parser for Relay's text format."""
from __future__ import absolute_import
def enabled():
"""Is the parser enabled/Can we import the parser?"""
try:
# pylint: disable=unused-variable
from tvm.relay import _parser
return True
# pylint: disable=broad-except
except Exception:
return False
def fromtext(data):
"""Parse a Relay program."""
from tvm.relay import _parser
return _parser.fromtext(data)
......@@ -156,7 +156,7 @@ class FuncType(Type):
class IncompleteType(Type):
"""An incomplete type."""
def __init__(self, kind):
def __init__(self, kind=Kind.Type):
self.__init_handle_by_constructor__(_make.IncompleteType, kind)
@register_relay_node
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment