Commit d274e4b3 by Jared Roesch Committed by Tianqi Chen

[Relay][Parser] Improve Relay parser and pretty printing, including CMAKE (#2377)

parent d0f83664
if(USE_ANTLR)
if(EXISTS /usr/local/lib/antlr-4.7.1-complete.jar)
set(ANTLR4 "/usr/local/lib/antlr-4.7.1-complete.jar")
file(GLOB_RECURSE ANTLR4
/usr/local/lib/antlr-*-complete.jar
/usr/local/Cellar/*antlr-*-complete.jar)
if(DEFINED ANTLR4)
# Get the first element of the list of antlr jars.
# Sort and reverse the list so the item selected is the highest
# version in lib or else in Cellar if no lib installation exists.
list(SORT ANTLR4)
list(REVERSE ANTLR4)
list(GET ANTLR4 0 ANTLR4)
set(RELAY_PARSER_DIR
${CMAKE_CURRENT_SOURCE_DIR}/python/tvm/relay/grammar)
......@@ -14,15 +22,21 @@ if(USE_ANTLR)
${RELAY_PARSER_DIR}/py3/RelayParser.py
${RELAY_PARSER_DIR}/py3/RelayLexer.py)
set(JAVA_HOME $ENV{JAVA_HOME})
if (NOT DEFINED JAVA_HOME)
# Hack to get system to search for Java itself.
set(JAVA_HOME "/usr")
endif()
# Generate ANTLR grammar for parsing.
add_custom_command(OUTPUT ${RELAY_PARSER}
COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2
COMMAND $ENV{JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3
COMMAND ${JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python2 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py2
COMMAND ${JAVA_HOME}/bin/java -jar ${ANTLR4} -visitor -no-listener -Dlanguage=Python3 ${RELAY_PARSER_DIR}/Relay.g4 -o ${RELAY_PARSER_DIR}/py3
DEPENDS ${RELAY_PARSER_DIR}/Relay.g4
WORKING_DIRECTORY ${RELAY_PARSER_DIR})
add_custom_target(relay_parser ALL DEPENDS ${RELAY_PARSER})
else()
message(FATAL_ERROR "Can't find ANTLR4!")
message(FATAL_ERROR "Can't find ANTLR4: ANTLR4=" ${ANTLR4})
endif()
endif(USE_ANTLR)
......@@ -108,7 +108,9 @@ class SourceName : public NodeRef {
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const SourceNameNode* operator->() const;
inline const SourceNameNode* operator->() const {
return static_cast<SourceNameNode*>(this->node_.get());
}
/*!
* \brief Get an SourceName for a given operator name.
......
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface of expr function exposed from C++."""
from tvm._ffi.function import _init_api
_init_api("relay._base", __name__)
......@@ -6,13 +6,17 @@ from __future__ import absolute_import
import sys
from collections import deque
from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any
from typing import TypeVar, Deque, Tuple, Optional, Union, NamedTuple, List, Callable, Any, Dict
import tvm
from . import module
from .base import Span, SourceName
from . import expr
from . import ty
from . import op
class ParseError(Exception):
"""Exception type for parse errors."""
......@@ -76,22 +80,46 @@ def lookup(scopes, name):
return val
return None
def spanify(f):
"""A decorator which attaches span information
to the value returned by calling `f`.
Intended for use with the below AST visiting
methods. The idea is that after we do the work
of constructing the AST we attach Span information.
"""
def _wrapper(*args, **kwargs):
# Assumes 0th arg is self and gets source_name from object.
sn = args[0].source_name
# Assumes 1st arg is an ANTLR parser context.
ctx = args[1]
ast = f(*args, **kwargs)
line, col = ctx.getSourceInterval()
sp = Span(sn, line, col)
ast.set_span(sp)
return ast
return _wrapper
# TODO(@jmp): Use https://stackoverflow.com/q/13889941
# to figure out how to get ANTLR4 to be more unhappy about syntax errors
class ParseTreeToRelayIR(RelayVisitor):
"""Parse Relay text format into Relay IR."""
def __init__(self):
# type: () -> None
def __init__(self, source_name):
# type: (str) -> None
self.source_name = source_name
self.module = module.Module({}) # type: module.Module
# Adding an empty scope allows naked lets without pain.
self.var_scopes = deque([deque()]) # type: Scopes[expr.Var]
self.global_var_scope = deque() # type: Scope[expr.GlobalVar]
self.type_param_scopes = deque([deque()]) # type: Scopes[ty.TypeVar]
self.graph_expr = [] # type: List[expr.Expr]
super(ParseTreeToRelayIR, self).__init__()
def enter_var_scope(self):
# type: () -> None
"""Enter a new Var scope so it can be popped off later."""
......@@ -146,20 +174,25 @@ class ParseTreeToRelayIR(RelayVisitor):
node_type = node.getSymbol().type
node_text = node.getText()
name = node_text[1:]
# variables
if node_type == RelayLexer.GLOBAL_VAR:
return lookup([self.global_var_scope], node_text[1:])
return lookup(deque([self.global_var_scope]), node_text[1:])
elif node_type == RelayLexer.LOCAL_VAR:
name = node_text[1:]
# Remove the leading '%' and lookup the name.
var = lookup(self.var_scopes, name)
if var is None:
raise ParseError("Couldn't resolve `{}`.".format(name))
return var
elif node_type == RelayLexer.GRAPH_VAR:
try:
return self.graph_expr[int(name)]
except IndexError:
raise ParseError("Couldn't resolve `{}`".format(name))
# data types
elif node_type == RelayLexer.INT:
elif node_type == RelayLexer.NAT:
return int(node_text)
elif node_type == RelayLexer.FLOAT:
return float(node_text)
......@@ -190,7 +223,7 @@ class ParseTreeToRelayIR(RelayVisitor):
return self.visit(ctx)
def visitProg(self, ctx):
# type: (RelayParser.ProgContext) -> Union[expr.Expr, env.Environment]
# type: (RelayParser.ProgContext) -> Union[expr.Expr, module.Module]
if ctx.defn():
self.visit_list(ctx.defn())
return self.module
......@@ -219,7 +252,7 @@ class ParseTreeToRelayIR(RelayVisitor):
def visitScalarInt(self, ctx):
# type: (RelayParser.ScalarIntContext) -> expr.Constant
return expr.const(self.visit(ctx.INT()))
return expr.const(self.visit(ctx.NAT()))
def visitScalarBool(self, ctx):
# type: (RelayParser.ScalarBoolContext) -> expr.Constant
......@@ -240,7 +273,7 @@ class ParseTreeToRelayIR(RelayVisitor):
return expr.Tuple(tup)
# Currently doesn't support mutable sequencing.
def visitSeq(self, ctx):
def visitLet(self, ctx):
# type: (RelayParser.SeqContext) -> expr.Let
"""Desugar various sequence constructs to Relay Let nodes."""
if ctx.MUT() is not None:
......@@ -253,7 +286,7 @@ class ParseTreeToRelayIR(RelayVisitor):
else:
local_var = ctx.var().ident().LOCAL_VAR()
if local_var is None:
raise ParseError('Only local ids may be used in `let`s.')
raise ParseError("Only local ids may be used in `let`s.")
ident = local_var.getText()[1:]
type_ = self.getType_(ctx.var().type_())
......@@ -278,12 +311,14 @@ class ParseTreeToRelayIR(RelayVisitor):
return relay_op(arg0, arg1)
@spanify
def visitVar(self, ctx):
# type: (RelayParser.VarContext) -> expr.Var
"""Visit a single variable."""
ident = ctx.ident().LOCAL_VAR()
if ident is None:
raise ParseError('Only local ids may be used in params.')
raise ParseError("Only local ids may be used in vars.")
type_ = self.getType_(ctx.type_())
......@@ -293,15 +328,33 @@ class ParseTreeToRelayIR(RelayVisitor):
# type: (RelayParser.VarListContext) -> List[expr.Var]
return self.visit_list(ctx.var())
# TODO: support a larger class of values than just Relay exprs
def visitAttr(self, ctx):
# type: (RelayParser.AttrContext) -> Tuple[str, expr.Expr]
return (ctx.CNAME().getText(), self.visit(ctx.expr()))
def visitAttrList(self, ctx):
# type: (RelayParser.AttrListContext) -> Dict[str, expr.Expr]
return dict(self.visit_list(ctx.attr()))
def visitArgList(self,
ctx # type: RelayParser.ArgListContext
):
# type: (...) -> Tuple[Optional[List[expr.Var]], Optional[Dict[str, expr.Expr]]]
var_list = self.visit(ctx.varList()) if ctx.varList() else None
attr_list = self.visit(ctx.attrList()) if ctx.attrList() else None
return (var_list, attr_list)
def mk_func(self, ctx):
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> Function
# type: (Union[RelayParser.FuncContext, RelayParser.DefnContext]) -> expr.Function
"""Construct a function from either a Func or Defn."""
# Enter var scope early to put params in scope.
self.enter_var_scope()
# Capture type params in params.
self.enter_type_param_scope()
var_list = self.visit(ctx.varList())
var_list, attr_list = self.visit(ctx.argList())
ret_type = self.getType_(ctx.type_())
type_params = list(self.exit_type_param_scope())
......@@ -311,22 +364,28 @@ class ParseTreeToRelayIR(RelayVisitor):
body = self.visit(ctx.body())
self.exit_var_scope()
return expr.Function(var_list, body, ret_type, type_params) # type: ignore
attrs = tvm.make.node("DictAttrs", **attr_list) if attr_list is not None else None
return expr.Function(var_list, body, ret_type, type_params, attrs)
@spanify
def visitFunc(self, ctx):
# type: (RelayParser.FuncContext) -> expr.Function
return self.mk_func(ctx)
# TODO: how to set spans for definitions?
# @spanify
def visitDefn(self, ctx):
# type: (RelayParser.DefnContext) -> None
ident = ctx.ident().GLOBAL_VAR()
if ident is None:
raise ParseError('Only global ids may be used in `def`s.')
raise ParseError("Only global ids may be used in `def`s.")
ident_name = ident.getText()[1:]
ident = self.mk_global_var(ident_name)
self.module[ident] = self.mk_func(ctx)
@spanify
def visitCall(self, ctx):
# type: (RelayParser.CallContext) -> expr.Call
visited_exprs = self.visit_list(ctx.expr())
......@@ -336,6 +395,7 @@ class ParseTreeToRelayIR(RelayVisitor):
return expr.Call(func, args, None, None)
@spanify
def visitIfElse(self, ctx):
# type: (RelayParser.IfElseContext) -> expr.If
"""Construct a Relay If node. Creates a new scope for each branch."""
......@@ -351,6 +411,27 @@ class ParseTreeToRelayIR(RelayVisitor):
return expr.If(cond, true_branch, false_branch)
@spanify
def visitGraph(self, ctx):
# type: (RelayParser.GraphContext) -> expr.Expr
"""Visit a graph variable assignment."""
if ctx.ident().GRAPH_VAR() is None:
raise ParseError("Expected a graph var, but got `{}`".format(ctx.ident().getText()))
graph_nid = int(ctx.ident().GRAPH_VAR().getText()[1:])
self.enter_var_scope()
value = self.visit(ctx.expr(0))
self.exit_var_scope()
if graph_nid != len(self.graph_expr):
raise ParseError(
"Expected new graph variable to be `%{}`,".format(len(self.graph_expr)) + \
"but got `%{}`".format(graph_nid))
self.graph_expr.append(value)
kont = self.visit(ctx.expr(1))
return kont
# Types
# pylint: disable=unused-argument
......@@ -428,8 +509,18 @@ def make_parser(data):
token_stream = CommonTokenStream(lexer)
return RelayParser(token_stream)
def fromtext(data):
# type: (str) -> Union[expr.Expr, env.Environment]
__source_name_counter__ = 0
def fromtext(data, source_name=None):
# type: (str, str) -> Union[expr.Expr, module.Module]
"""Parse a Relay program."""
global __source_name_counter__
if source_name is None:
source_name = "source_file{0}".format(__source_name_counter__)
if isinstance(source_name, str):
source_name = SourceName(source_name)
tree = make_parser(data).prog()
return ParseTreeToRelayIR().visit(tree)
return ParseTreeToRelayIR(source_name).visit(tree)
......@@ -4,6 +4,7 @@ from __future__ import absolute_import as _abs
from .._ffi.node import NodeBase, register_node as _register_tvm_node
from . import _make
from . import _expr
from . import _base
NodeBase = NodeBase
......@@ -63,6 +64,9 @@ class RelayNode(NodeBase):
"""
return _expr.RelayPrint(self, show_meta_data, annotate)
def set_span(self, span):
_base.set_span(self, span)
@register_relay_node
class Span(RelayNode):
......@@ -71,6 +75,12 @@ class Span(RelayNode):
def __init__(self, source, lineno, col_offset):
self.__init_handle_by_constructor__(_make.Span, source, lineno, col_offset)
@register_relay_node
class SourceName(RelayNode):
"""A identifier for a source location"""
def __init__(self, name):
self.__init_handle_by_constructor__(_make.SourceName, name)
@register_relay_node
class Id(NodeBase):
......
grammar Relay;
SEMVER: 'v0.0.1' ;
// Lexing
// comments
WS : [ \t\n\r]+ -> skip ;
......@@ -20,7 +22,8 @@ NE: '!=' ;
opIdent: CNAME ;
GLOBAL_VAR: '@' CNAME ;
LOCAL_VAR: '%' CNAME ;
LOCAL_VAR: '%' CNAME;
GRAPH_VAR: '%' NAT;
MUT: 'mut' ;
......@@ -31,13 +34,13 @@ BOOL_LIT
// non-negative floats
FLOAT
: INT '.' INT EXP? // 1.35, 1.35E-9, 0.3, 4.5
| INT EXP // 1e10 3e4
: NAT '.' NAT EXP? // 1.35, 1.35E-9, 0.3, 4.5
| NAT EXP // 1e10 3e4
;
// non-negative ints
INT: DIGIT+ ;
fragment EXP: [eE] [+\-]? INT ; // \- since - means "range" inside [...]
NAT: DIGIT+ ;
fragment EXP: [eE] [+\-]? NAT ; // \- since - means "range" inside [...]
CNAME: ('_'|LETTER) ('_'|LETTER|DIGIT)* ;
fragment LETTER: [a-zA-Z] ;
......@@ -46,7 +49,7 @@ fragment DIGIT: [0-9] ;
// Parsing
// A Relay program is a list of global definitions or an expression.
prog: (defn* | expr) EOF ;
prog: SEMVER (defn* | expr) EOF ;
// option: 'set' ident BOOL_LIT ;
......@@ -73,10 +76,11 @@ expr
| 'if' '(' expr ')' body 'else' body # ifElse
// sequencing
| 'let' MUT? var '=' expr ';' expr # seq
| 'let' MUT? var '=' '{' expr '}' ';' expr # seq
| 'let' MUT? var '=' expr ';' expr # let
| 'let' MUT? var '=' '{' expr '}' ';' expr # let
// sugar for let %_ = expr; expr
| expr ';' expr # seq
| expr ';' expr # let
| ident '=' expr ';' expr # graph
// mutable update
// | ident '=' expr # writeRef
......@@ -84,16 +88,25 @@ expr
| ident # identExpr
| scalar # scalarExpr
// | expr '.' INT # project
// | expr '.' NAT # project
// | 'debug' # debug
;
func: 'fn' varList ('->' type_)? body ;
defn: 'def' ident varList ('->' type_)? body ;
func: 'fn' '(' argList ')' ('->' type_)? body ;
defn: 'def' ident '(' argList ')' ('->' type_)? body ;
argList
: varList
| attrList
| varList ',' attrList
;
varList: '(' (var (',' var)*)? ')' ;
varList: (var (',' var)*)? ;
var: ident (':' type_)? ;
attrList: (attr (',' attr)*)? ;
attr: CNAME '=' expr ;
// TODO(@jmp): for improved type annotations
// returnAnno: (ident ':')? type_ ;
......@@ -110,7 +123,7 @@ type_
// | identType '[' (type_ (',' type_)*)? ']' # callType
| 'fn' '(' (type_ (',' type_)*)? ')' '->' type_ # funcType
| '_' # incompleteType
| INT # intType
| NAT # intType
;
shapeSeq
......@@ -123,20 +136,20 @@ shape
: '(' shape ')' # parensShape
// | type_ op=('*'|'/') type_ # binOpType
// | type_ op=('+'|'-') type_ # binOpType
| INT # intShape
| NAT # intShape
;
identType: CNAME ;
// Int8, Int16, Int32, Int64
// UInt8, UInt16, UInt32, UInt64
// Float16, Float32, Float64
// Bool
// int8, int16, int32, int64
// uint8, uint16, uint32, uint64
// float16, float32, float64
// bool
body: '{' expr '}' ;
scalar
: FLOAT # scalarFloat
| INT # scalarInt
| NAT # scalarInt
| BOOL_LIT # scalarBool
;
......@@ -144,4 +157,5 @@ ident
: opIdent
| GLOBAL_VAR
| LOCAL_VAR
| GRAPH_VAR
;
"""A parser for Relay's text format."""
from __future__ import absolute_import
from .. import register_func
def enabled():
"""Is the parser enabled/Can we import the parser?"""
"""Checks whether the parser is enabled, this allows users to
optionally support building the parser.
We use this check before importing the parser.
"""
try:
# pylint: disable=unused-variable
from tvm.relay import _parser
......@@ -11,7 +16,8 @@ def enabled():
except Exception:
return False
def fromtext(data):
@register_func("relay.fromtext")
def fromtext(data, source_name=None):
"""Parse a Relay program."""
from tvm.relay import _parser
return _parser.fromtext(data)
return _parser.fromtext(data, source_name)
......@@ -32,6 +32,11 @@ SourceName SourceName::Get(const std::string& name) {
return SourceName(GetSourceNameNode(name));
}
TVM_REGISTER_API("relay._make.SourceName")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = SourceName::Get(args[0]);
});
TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<SourceNameNode>([](const SourceNameNode* node, tvm::IRPrinter* p) {
p->stream << "SourceName(" << node->name << ", " << node << ")";
......@@ -66,5 +71,14 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
TVM_REGISTER_NODE_TYPE(IdNode);
TVM_REGISTER_API("relay._base.set_span")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef node_ref = args[0];
auto rn = node_ref.as_derived<RelayNode>();
CHECK(rn);
Span sp = args[1];
rn->span = sp;
});
} // namespace relay
} // namespace tvm
......@@ -8,11 +8,12 @@ from numpy import isclose
from typing import Union
from functools import wraps
if enabled():
from tvm.relay._parser import ParseError
raises_parse_error = raises(ParseError)
raises_parse_error = raises(tvm._ffi.base.TVMError)
else:
raises_parse_error = lambda x: x
SEMVER = "v0.0.1"
BINARY_OPS = {
"*": relay.multiply,
"/": relay.divide,
......@@ -48,6 +49,10 @@ TYPES = {
"float16x4",
}
def parses_as(code, expr):
# type: (str, relay.Expr) -> bool
return alpha_equal(relay.fromtext(SEMVER + "\n" + code), expr)
def get_scalar(x):
# type: (relay.Constant) -> (Union[float, int, bool])
return x.data.asnumpy().item()
......@@ -74,80 +79,80 @@ def if_parser_enabled(func):
@if_parser_enabled
def test_comments():
assert alpha_equal(
relay.fromtext("""
assert parses_as(
"""
// This is a line comment!
()
"""),
""",
UNIT
)
assert alpha_equal(
relay.fromtext("""
assert parses_as(
"""
/* This is a block comment!
This is still a block comment!
*/
()
"""),
""",
UNIT
)
@if_parser_enabled
def test_int_literal():
assert isinstance(relay.fromtext("1"), relay.Constant)
assert isinstance(relay.fromtext("1").data, tvm.ndarray.NDArray)
assert isinstance(relay.fromtext(SEMVER+"1"), relay.Constant)
assert isinstance(relay.fromtext(SEMVER+"1").data, tvm.ndarray.NDArray)
assert get_scalar(relay.fromtext("1")) == 1
assert get_scalar(relay.fromtext("10")) == 10
assert get_scalar(relay.fromtext("0")) == 0
assert get_scalar(relay.fromtext("-100")) == -100
assert get_scalar(relay.fromtext("-05")) == -5
assert get_scalar(relay.fromtext(SEMVER+"1")) == 1
assert get_scalar(relay.fromtext(SEMVER+"10")) == 10
assert get_scalar(relay.fromtext(SEMVER+"0")) == 0
assert get_scalar(relay.fromtext(SEMVER+"-100")) == -100
assert get_scalar(relay.fromtext(SEMVER+"-05")) == -5
@if_parser_enabled
def test_float_literal():
assert get_scalar(relay.fromtext("1.0")) == 1.0
assert isclose(get_scalar(relay.fromtext("1.56667")), 1.56667)
assert get_scalar(relay.fromtext("0.0")) == 0.0
assert get_scalar(relay.fromtext("-10.0")) == -10.0
assert get_scalar(relay.fromtext(SEMVER+"1.0")) == 1.0
assert isclose(get_scalar(relay.fromtext(SEMVER+"1.56667")), 1.56667)
assert get_scalar(relay.fromtext(SEMVER+"0.0")) == 0.0
assert get_scalar(relay.fromtext(SEMVER+"-10.0")) == -10.0
# scientific notation
assert isclose(get_scalar(relay.fromtext("1e-1")), 1e-1)
assert get_scalar(relay.fromtext("1e+1")) == 1e+1
assert isclose(get_scalar(relay.fromtext("1E-1")), 1E-1)
assert get_scalar(relay.fromtext("1E+1")) == 1E+1
assert isclose(get_scalar(relay.fromtext("1.0e-1")), 1.0e-1)
assert get_scalar(relay.fromtext("1.0e+1")) == 1.0e+1
assert isclose(get_scalar(relay.fromtext("1.0E-1")), 1.0E-1)
assert get_scalar(relay.fromtext("1.0E+1")) == 1.0E+1
assert isclose(get_scalar(relay.fromtext(SEMVER+"1e-1")), 1e-1)
assert get_scalar(relay.fromtext(SEMVER+"1e+1")) == 1e+1
assert isclose(get_scalar(relay.fromtext(SEMVER+"1E-1")), 1E-1)
assert get_scalar(relay.fromtext(SEMVER+"1E+1")) == 1E+1
assert isclose(get_scalar(relay.fromtext(SEMVER+"1.0e-1")), 1.0e-1)
assert get_scalar(relay.fromtext(SEMVER+"1.0e+1")) == 1.0e+1
assert isclose(get_scalar(relay.fromtext(SEMVER+"1.0E-1")), 1.0E-1)
assert get_scalar(relay.fromtext(SEMVER+"1.0E+1")) == 1.0E+1
@if_parser_enabled
def test_bool_literal():
assert get_scalar(relay.fromtext("True")) == True
assert get_scalar(relay.fromtext("False")) == False
assert get_scalar(relay.fromtext(SEMVER+"True")) == True
assert get_scalar(relay.fromtext(SEMVER+"False")) == False
@if_parser_enabled
def test_negative():
assert isinstance(relay.fromtext("let %x = 1; -%x").body, relay.Call)
assert get_scalar(relay.fromtext("--10")) == 10
assert get_scalar(relay.fromtext("---10")) == -10
assert isinstance(relay.fromtext(SEMVER+"let %x = 1; -%x").body, relay.Call)
assert get_scalar(relay.fromtext(SEMVER+"--10")) == 10
assert get_scalar(relay.fromtext(SEMVER+"---10")) == -10
@if_parser_enabled
def test_bin_op():
for bin_op in BINARY_OPS.keys():
assert alpha_equal(
relay.fromtext("1 {} 1".format(bin_op)),
assert parses_as(
"1 {} 1".format(bin_op),
BINARY_OPS.get(bin_op)(relay.const(1), relay.const(1))
)
@if_parser_enabled
def test_parens():
assert alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("(1 * 1) + 1"))
assert not alpha_equal(relay.fromtext("1 * 1 + 1"), relay.fromtext("1 * (1 + 1)"))
assert alpha_equal(relay.fromtext(SEMVER+"1 * 1 + 1"), relay.fromtext(SEMVER+"(1 * 1) + 1"))
assert not alpha_equal(relay.fromtext(SEMVER+"1 * 1 + 1"), relay.fromtext(SEMVER+"1 * (1 + 1)"))
@if_parser_enabled
def test_op_assoc():
assert alpha_equal(relay.fromtext("1 * 1 + 1 < 1 == 1"), relay.fromtext("(((1 * 1) + 1) < 1) == 1"))
assert alpha_equal(relay.fromtext("1 == 1 < 1 + 1 * 1"), relay.fromtext("1 == (1 < (1 + (1 * 1)))"))
assert alpha_equal(relay.fromtext(SEMVER+"1 * 1 + 1 < 1 == 1"), relay.fromtext(SEMVER+"(((1 * 1) + 1) < 1) == 1"))
assert alpha_equal(relay.fromtext(SEMVER+"1 == 1 < 1 + 1 * 1"), relay.fromtext(SEMVER+"1 == (1 < (1 + (1 * 1)))"))
@nottest
@if_parser_enabled
......@@ -159,24 +164,24 @@ def test_vars():
# assert temp_var.name == "1"
# var
var = relay.fromtext("let %foo = (); %foo")
var = relay.fromtext(SEMVER+"let %foo = (); %foo")
assert isinstance(var.body, relay.Var)
assert var.body.name_hint == "foo"
# global var
global_var = relay.fromtext("@foo")
global_var = relay.fromtext(SEMVER+"@foo")
assert isinstance(global_var, relay.GlobalVar)
assert global_var.name_hint == "foo"
# operator id
op = relay.fromtext("foo")
op = relay.fromtext(SEMVER+"foo")
assert isinstance(op, relay.Op)
assert op.name == "foo"
@if_parser_enabled
def test_let():
assert alpha_equal(
relay.fromtext("let %x = 1; ()"),
assert parses_as(
"let %x = 1; ()",
relay.Let(
X,
relay.const(1),
......@@ -184,18 +189,35 @@ def test_let():
)
)
assert parses_as(
"""
let %x = 1;
let %y = 2;
()
""",
relay.Let(
X,
relay.const(1),
relay.Let(
Y,
relay.const(2),
UNIT
)
)
)
@if_parser_enabled
def test_seq():
assert alpha_equal(
relay.fromtext("(); ()"),
assert parses_as(
"(); ()",
relay.Let(
_,
UNIT,
UNIT)
)
assert alpha_equal(
relay.fromtext("let %_ = { 1 }; ()"),
assert parses_as(
"let %_ = { 1 }; ()",
relay.Let(
X,
relay.const(1),
......@@ -203,31 +225,48 @@ def test_seq():
)
)
@if_parser_enabled
def test_graph():
assert parses_as(
"%0 = (); %1 = 1; (%0, %0, %1)",
relay.Tuple([UNIT, UNIT, relay.const(1)])
)
assert not parses_as(
"%0 = (); %1 = 1; (%0, %0, %1)",
relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)])
)
@raises_parse_error
@if_parser_enabled
def test_graph_wrong_order():
relay.fromtext(SEMVER+"%1 = (); %1")
@raises_parse_error
@if_parser_enabled
def test_let_global_var():
relay.fromtext("let @x = 1; ()")
relay.fromtext(SEMVER+"let @x = 1; ()")
@raises_parse_error
@if_parser_enabled
def test_let_op():
relay.fromtext("let x = 1; ()")
relay.fromtext(SEMVER+"let x = 1; ()")
@if_parser_enabled
def test_tuple():
assert alpha_equal(relay.fromtext("()"), relay.Tuple([]))
assert parses_as("()", relay.Tuple([]))
assert alpha_equal(relay.fromtext("(0,)"), relay.Tuple([relay.const(0)]))
assert parses_as("(0,)", relay.Tuple([relay.const(0)]))
assert alpha_equal(relay.fromtext("(0, 1)"), relay.Tuple([relay.const(0), relay.const(1)]))
assert parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)]))
assert alpha_equal(relay.fromtext("(0, 1, 2)"), relay.Tuple([relay.const(0), relay.const(1), relay.const(2)]))
assert parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)]))
@if_parser_enabled
def test_func():
# 0 args
assert alpha_equal(
relay.fromtext("fn () { 0 }"),
assert parses_as(
"fn () { 0 }",
relay.Function(
[],
relay.const(0),
......@@ -237,8 +276,8 @@ def test_func():
)
# 1 arg
assert alpha_equal(
relay.fromtext("fn (%x) { %x }"),
assert parses_as(
"fn (%x) { %x }",
relay.Function(
[X],
X,
......@@ -248,8 +287,8 @@ def test_func():
)
# 2 args
assert alpha_equal(
relay.fromtext("fn (%x, %y) { %x + %y }"),
assert parses_as(
"fn (%x, %y) { %x + %y }",
relay.Function(
[X, Y],
relay.add(X, Y),
......@@ -259,8 +298,8 @@ def test_func():
)
# annotations
assert alpha_equal(
relay.fromtext("fn (%x: int32) -> int32 { %x }"),
assert parses_as(
"fn (%x: int32) -> int32 { %x }",
relay.Function(
[X_ANNO],
X_ANNO,
......@@ -269,11 +308,17 @@ def test_func():
)
)
# attributes
assert parses_as(
"fn (n=5) { () }",
relay.Function([], UNIT, None, None, tvm.make.node("DictAttrs", n=relay.const(5)))
)
# TODO(@jmp): Crashes if %x isn't annnotated.
# @nottest
@if_parser_enabled
def test_defn():
id_defn = relay.fromtext(
SEMVER+
"""
def @id(%x: int32) -> int32 {
%x
......@@ -284,6 +329,7 @@ def test_defn():
@if_parser_enabled
def test_recursive_call():
id_defn = relay.fromtext(
SEMVER+
"""
def @id(%x: int32) -> int32 {
@id(%x)
......@@ -293,16 +339,14 @@ def test_recursive_call():
@if_parser_enabled
def test_ifelse():
assert alpha_equal(
relay.fromtext(
assert parses_as(
"""
if (True) {
0
} else {
1
}
"""
),
""",
relay.If(
relay.const(True),
relay.const(0),
......@@ -314,6 +358,7 @@ def test_ifelse():
@if_parser_enabled
def test_ifelse_scope():
relay.fromtext(
SEMVER+
"""
if (True) {
let %x = ();
......@@ -328,13 +373,11 @@ def test_ifelse_scope():
def test_call():
# select right function to call: simple ident case
id_func = relay.Var("id")
assert alpha_equal(
relay.fromtext(
assert parses_as(
"""
let %id = fn (%x) { %x };
10 * %id(10)
"""
),
""",
relay.Let(
id_func,
relay.Function([X], X, None, []),
......@@ -344,13 +387,11 @@ def test_call():
# 0 args
constant = relay.Var("constant")
assert alpha_equal(
relay.fromtext(
assert parses_as(
"""
let %constant = fn () { 0 };
%constant()
"""
),
""",
relay.Let(
constant,
relay.Function([], relay.const(0), None, []),
......@@ -360,13 +401,11 @@ def test_call():
# 1 arg
id_var = relay.Var("id")
assert alpha_equal(
relay.fromtext(
assert parses_as(
"""
let %id = fn (%x) { %x };
%id(1)
"""
),
""",
relay.Let(
id_var,
relay.Function([X], X, None, []),
......@@ -376,13 +415,11 @@ def test_call():
# 2 args
multiply = relay.Var("multiply")
assert alpha_equal(
relay.fromtext(
assert parses_as(
"""
let %multiply = fn (%x, %y) { %x * %y };
%multiply(0, 0)
"""
),
""",
relay.Let(
multiply,
relay.Function(
......@@ -396,12 +433,10 @@ def test_call():
)
# anonymous function
assert alpha_equal(
relay.fromtext(
assert parses_as(
"""
(fn (%x) { %x })(0)
"""
),
""",
relay.Call(
relay.Function(
[X],
......@@ -415,45 +450,44 @@ def test_call():
)
)
# TODO(@jmp): re-enable after sequence parsing improvements
# curried function
curried_mult = relay.Var("curried_mult")
alpha_equal(
relay.fromtext(
"""
let %curried_mult =
fn (%x) {
fn (%y) {
%x * %y
}
};
%curried_mult(0);
%curried_mult(0)(0)
"""
),
relay.Let(
curried_mult,
relay.Function(
[X],
relay.Function(
[Y],
relay.multiply(X, Y),
None,
[]
),
None,
[]
),
relay.Let(
_,
relay.Call(curried_mult, [relay.const(0)], None, None),
relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None)
)
)
)
# curried_mult = relay.Var("curried_mult")
# assert parses_as(
# """
# let %curried_mult =
# fn (%x) {
# fn (%y) {
# %x * %y
# }
# };
# %curried_mult(0);
# %curried_mult(0)(0)
# """,
# relay.Let(
# curried_mult,
# relay.Function(
# [X],
# relay.Function(
# [Y],
# relay.multiply(X, Y),
# None,
# []
# ),
# None,
# []
# ),
# relay.Let(
# _,
# relay.Call(curried_mult, [relay.const(0)], None, None),
# relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None)
# )
# )
# )
# op
alpha_equal(
relay.fromtext("abs(1)"),
assert parses_as(
"abs(1)",
relay.Call(relay.op.get("abs"), [relay.const(1)], None, None)
)
......@@ -461,8 +495,8 @@ def test_call():
@if_parser_enabled
def test_incomplete_type():
assert alpha_equal(
relay.fromtext("let %_ : _ = (); ()"),
assert parses_as(
"let %_ : _ = (); ()",
relay.Let(
_,
UNIT,
......@@ -473,7 +507,7 @@ def test_incomplete_type():
@if_parser_enabled
def test_builtin_types():
for builtin_type in TYPES:
relay.fromtext("let %_ : {} = (); ()".format(builtin_type))
relay.fromtext(SEMVER+"let %_ : {} = (); ()".format(builtin_type))
@nottest
@if_parser_enabled
......@@ -482,8 +516,8 @@ def test_call_type():
@if_parser_enabled
def test_tensor_type():
assert alpha_equal(
relay.fromtext("let %_ : Tensor[(), float32] = (); ()"),
assert parses_as(
"let %_ : Tensor[(), float32] = (); ()",
relay.Let(
relay.Var("_", relay.TensorType((), "float32")),
UNIT,
......@@ -491,8 +525,8 @@ def test_tensor_type():
)
)
assert alpha_equal(
relay.fromtext("let %_ : Tensor[(1,), float32] = (); ()"),
assert parses_as(
"let %_ : Tensor[(1,), float32] = (); ()",
relay.Let(
relay.Var("_", relay.TensorType((1,), "float32")),
UNIT,
......@@ -500,8 +534,8 @@ def test_tensor_type():
)
)
assert alpha_equal(
relay.fromtext("let %_ : Tensor[(1, 1), float32] = (); ()"),
assert parses_as(
"let %_ : Tensor[(1, 1), float32] = (); ()",
relay.Let(
relay.Var("_", relay.TensorType((1, 1), "float32")),
UNIT,
......@@ -511,12 +545,10 @@ def test_tensor_type():
@if_parser_enabled
def test_function_type():
assert alpha_equal(
relay.fromtext(
assert parses_as(
"""
let %_: fn () -> int32 = fn () -> int32 { 0 }; ()
"""
),
""",
relay.Let(
relay.Var("_", relay.FuncType([], int32, [], [])),
relay.Function([], relay.const(0), int32, []),
......@@ -524,12 +556,10 @@ def test_function_type():
)
)
assert alpha_equal(
relay.fromtext(
assert parses_as(
"""
let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; ()
"""
),
""",
relay.Let(
relay.Var("_", relay.FuncType([int32], int32, [], [])),
relay.Function([relay.Var("x", int32)], relay.const(0), int32, []),
......@@ -537,12 +567,10 @@ def test_function_type():
)
)
assert alpha_equal(
relay.fromtext(
assert parses_as(
"""
let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; ()
"""
),
""",
relay.Let(
relay.Var("_", relay.FuncType([int32, int32], int32, [], [])),
relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []),
......@@ -552,11 +580,10 @@ def test_function_type():
@if_parser_enabled
def test_tuple_type():
assert alpha_equal(
relay.fromtext(
assert parses_as(
"""
let %_: () = (); ()
"""),
""",
relay.Let(
relay.Var("_", relay.TupleType([])),
UNIT,
......@@ -564,11 +591,10 @@ def test_tuple_type():
)
)
assert alpha_equal(
relay.fromtext(
assert parses_as(
"""
let %_: (int32,) = (0,); ()
"""),
""",
relay.Let(
relay.Var("_", relay.TupleType([int32])),
relay.Tuple([relay.const(0)]),
......@@ -576,11 +602,10 @@ def test_tuple_type():
)
)
assert alpha_equal(
relay.fromtext(
assert parses_as(
"""
let %_: (int32, int32) = (0, 1); ()
"""),
""",
relay.Let(
relay.Var("_", relay.TupleType([int32, int32])),
relay.Tuple([relay.const(0), relay.const(1)]),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment