"""Tests for the Relay text-format parser (``relay.fromtext``)."""
import tvm
from tvm import relay
from tvm.relay.parser import enabled
from tvm.relay.ir_pass import alpha_equal
from nose import SkipTest
from nose.tools import nottest, raises
from numpy import isclose
from typing import Union
from functools import wraps

if enabled():
    raises_parse_error = raises(tvm._ffi.base.TVMError)
else:
    def raises_parse_error(func):
        # Without the parser, `if_parser_enabled` skips the test, so there is
        # no parse error to expect; leave the test undecorated.
        return func

# Version header every Relay text program must start with.
SEMVER = "v0.0.1"

# Maps the textual binary operator to the Relay op constructor it parses to.
BINARY_OPS = {
    "*": relay.multiply,
    "/": relay.divide,
    "+": relay.add,
    "-": relay.subtract,
    "<": relay.less,
    ">": relay.greater,
    "<=": relay.less_equal,
    ">=": relay.greater_equal,
    "==": relay.equal,
    "!=": relay.not_equal,
}

# Builtin type names that must be accepted in a type annotation.
TYPES = {
    "int8",
    "int16",
    "int32",
    "int64",

    "uint8",
    "uint16",
    "uint32",
    "uint64",

    "float16",
    "float32",
    "float64",

    "bool",

    "int8x4",
    "uint1x4",
    "float16x4",
}


def parse_text(code):
    # type: (str) -> Union[relay.Expr, relay.Module]
    """Parse `code` as Relay text, prefixing the version header on its own line.

    The newline guarantees the header is lexically separated from the program,
    so the result never depends on how the lexer splits e.g. "v0.0.1let".
    """
    return relay.fromtext(SEMVER + "\n" + code)


def parses_as(code, expr):
    # type: (str, relay.Expr) -> bool
    """Return True if `code` parses to an expression alpha-equal to `expr`."""
    return alpha_equal(parse_text(code), expr)


def get_scalar(x):
    # type: (relay.Constant) -> (Union[float, int, bool])
    """Return the Python scalar held by a scalar Relay constant."""
    return x.data.asnumpy().item()


int32 = relay.scalar_type("int32")

# Shared variables used to build expected ASTs. Names only matter for
# readability: alpha_equal compares bound variables up to renaming.
_ = relay.Var("_")
X = relay.Var("x")
Y = relay.Var("y")
X_ANNO = relay.Var("x", int32)
Y_ANNO = relay.Var("y", int32)

UNIT = relay.Tuple([])


def if_parser_enabled(func):
    """Decorate a test so it is skipped when the ANTLR-based parser is unavailable."""
    # https://stackoverflow.com/q/7727678
    @wraps(func)
    def wrapper(*args, **kwargs):
        if not enabled():
            raise SkipTest("ANTLR is not installed!")
        return func(*args, **kwargs)
    return wrapper


@if_parser_enabled
def test_comments():
    assert parses_as(
        """
        // This is a line comment!
        ()
        """,
        UNIT
    )

    assert parses_as(
        """
        /* This is a block comment!
            This is still a block comment!
        */
        ()
        """,
        UNIT
    )


@if_parser_enabled
def test_int_literal():
    assert isinstance(parse_text("1"), relay.Constant)
    assert isinstance(parse_text("1").data, tvm.ndarray.NDArray)

    assert get_scalar(parse_text("1")) == 1
    assert get_scalar(parse_text("10")) == 10
    assert get_scalar(parse_text("0")) == 0
    assert get_scalar(parse_text("-100")) == -100
    assert get_scalar(parse_text("-05")) == -5


@if_parser_enabled
def test_float_literal():
    assert get_scalar(parse_text("1.0")) == 1.0
    assert isclose(get_scalar(parse_text("1.56667")), 1.56667)
    assert get_scalar(parse_text("0.0")) == 0.0
    assert get_scalar(parse_text("-10.0")) == -10.0

    # scientific notation
    assert isclose(get_scalar(parse_text("1e-1")), 1e-1)
    assert get_scalar(parse_text("1e+1")) == 1e+1
    assert isclose(get_scalar(parse_text("1E-1")), 1E-1)
    assert get_scalar(parse_text("1E+1")) == 1E+1
    assert isclose(get_scalar(parse_text("1.0e-1")), 1.0e-1)
    assert get_scalar(parse_text("1.0e+1")) == 1.0e+1
    assert isclose(get_scalar(parse_text("1.0E-1")), 1.0E-1)
    assert get_scalar(parse_text("1.0E+1")) == 1.0E+1


@if_parser_enabled
def test_bool_literal():
    # ndarray.item() yields Python bools, so identity comparison is exact.
    assert get_scalar(parse_text("True")) is True
    assert get_scalar(parse_text("False")) is False


@if_parser_enabled
def test_negative():
    assert isinstance(parse_text("let %x = 1; -%x").body, relay.Call)
    assert get_scalar(parse_text("--10")) == 10
    assert get_scalar(parse_text("---10")) == -10


@if_parser_enabled
def test_bin_op():
    for op_text, op_builder in BINARY_OPS.items():
        assert parses_as(
            "1 {} 1".format(op_text),
            op_builder(relay.const(1), relay.const(1))
        )


@if_parser_enabled
def test_parens():
    assert alpha_equal(parse_text("1 * 1 + 1"), parse_text("(1 * 1) + 1"))
    assert not alpha_equal(parse_text("1 * 1 + 1"), parse_text("1 * (1 + 1)"))


@if_parser_enabled
def test_op_assoc():
    assert alpha_equal(parse_text("1 * 1 + 1 < 1 == 1"), parse_text("(((1 * 1) + 1) < 1) == 1"))
    assert alpha_equal(parse_text("1 == 1 < 1 + 1 * 1"), parse_text("1 == (1 < (1 + (1 * 1)))"))


@nottest
@if_parser_enabled
def test_vars():
    # NOTE: temp vars (e.g. "%1") are not exercised here because they start
    # with a digit, which this check does not yet support.

    # var
    var = parse_text("let %foo = (); %foo")
    assert isinstance(var.body, relay.Var)
    assert var.body.name_hint == "foo"

    # global var
    global_var = parse_text("@foo")
    assert isinstance(global_var, relay.GlobalVar)
    assert global_var.name_hint == "foo"

    # operator id
    op = parse_text("foo")
    assert isinstance(op, relay.Op)
    assert op.name == "foo"


@if_parser_enabled
def test_let():
    assert parses_as(
        "let %x = 1; ()",
        relay.Let(
            X,
            relay.const(1),
            UNIT
        )
    )

    assert parses_as(
        """
        let %x = 1;
        let %y = 2;
        ()
        """,
        relay.Let(
            X,
            relay.const(1),
            relay.Let(
                Y,
                relay.const(2),
                UNIT
            )
        )
    )


@if_parser_enabled
def test_seq():
    assert parses_as(
        "(); ()",
        relay.Let(
            _,
            UNIT,
            UNIT)
    )

    assert parses_as(
        "let %_ = { 1 }; ()",
        relay.Let(
            _,
            relay.const(1),
            UNIT
        )
    )


@if_parser_enabled
def test_graph():
    assert parses_as(
        "%0 = (); %1 = 1; (%0, %0, %1)",
        relay.Tuple([UNIT, UNIT, relay.const(1)])
    )

    # Graph bindings share nodes; two distinct empty tuples are not equal to that.
    assert not parses_as(
        "%0 = (); %1 = 1; (%0, %0, %1)",
        relay.Tuple([relay.Tuple([]), relay.Tuple([]), relay.const(1)])
    )


@raises_parse_error
@if_parser_enabled
def test_graph_wrong_order():
    parse_text("%1 = (); %1")


@raises_parse_error
@if_parser_enabled
def test_let_global_var():
    parse_text("let @x = 1; ()")


@raises_parse_error
@if_parser_enabled
def test_let_op():
    parse_text("let x = 1; ()")


@if_parser_enabled
def test_tuple():
    assert parses_as("()", relay.Tuple([]))

    assert parses_as("(0,)", relay.Tuple([relay.const(0)]))

    assert parses_as("(0, 1)", relay.Tuple([relay.const(0), relay.const(1)]))

    assert parses_as("(0, 1, 2)", relay.Tuple([relay.const(0), relay.const(1), relay.const(2)]))


@if_parser_enabled
def test_func():
    # 0 args
    assert parses_as(
        "fn () { 0 }",
        relay.Function(
            [],
            relay.const(0),
            None,
            []
        )
    )

    # 1 arg
    assert parses_as(
        "fn (%x) { %x }",
        relay.Function(
            [X],
            X,
            None,
            []
        )
    )

    # 2 args
    assert parses_as(
        "fn (%x, %y) { %x + %y }",
        relay.Function(
            [X, Y],
            relay.add(X, Y),
            None,
            []
        )
    )

    # annotations
    assert parses_as(
        "fn (%x: int32) -> int32 { %x }",
        relay.Function(
            [X_ANNO],
            X_ANNO,
            int32,
            []
        )
    )

    # attributes
    assert parses_as(
        "fn (n=5) { () }",
        relay.Function([], UNIT, None, None, tvm.make.node("DictAttrs", n=relay.const(5)))
    )


# TODO(@jmp): Crashes if %x isn't annotated.
@if_parser_enabled
def test_defn():
    id_defn = parse_text(
        """
        def @id(%x: int32) -> int32 {
            %x
        }
        """)
    assert isinstance(id_defn, relay.Module)


@if_parser_enabled
def test_recursive_call():
    id_defn = parse_text(
        """
        def @id(%x: int32) -> int32 {
            @id(%x)
        }
        """)
    assert isinstance(id_defn, relay.Module)


@if_parser_enabled
def test_ifelse():
    assert parses_as(
        """
        if (True) {
            0
        } else {
            1
        }
        """,
        relay.If(
            relay.const(True),
            relay.const(0),
            relay.const(1)
        )
    )


@raises_parse_error
@if_parser_enabled
def test_ifelse_scope():
    # %x is bound only inside the `then` branch, so the `else` use must fail.
    parse_text(
        """
        if (True) {
            let %x = ();
            ()
        } else {
            %x
        }
        """
    )


@if_parser_enabled
def test_call():
    # select right function to call: simple ident case
    id_func = relay.Var("id")
    assert parses_as(
        """
        let %id = fn (%x) { %x };
        10 * %id(10)
        """,
        relay.Let(
            id_func,
            relay.Function([X], X, None, []),
            relay.multiply(relay.const(10), relay.Call(id_func, [relay.const(10)]))
        )
    )

    # 0 args
    constant = relay.Var("constant")
    assert parses_as(
        """
        let %constant = fn () { 0 };
        %constant()
        """,
        relay.Let(
            constant,
            relay.Function([], relay.const(0), None, []),
            relay.Call(constant, [], None, None)
        )
    )

    # 1 arg
    id_var = relay.Var("id")
    assert parses_as(
        """
        let %id = fn (%x) { %x };
        %id(1)
        """,
        relay.Let(
            id_var,
            relay.Function([X], X, None, []),
            relay.Call(id_var, [relay.const(1)], None, None)
        )
    )

    # 2 args
    multiply = relay.Var("multiply")
    assert parses_as(
        """
        let %multiply = fn (%x, %y) { %x * %y };
        %multiply(0, 0)
        """,
        relay.Let(
            multiply,
            relay.Function(
                [X, Y],
                relay.multiply(X, Y),
                None,
                []
            ),
            relay.Call(multiply, [relay.const(0), relay.const(0)], None, None)
        )
    )

    # anonymous function
    assert parses_as(
        """
        (fn (%x) { %x })(0)
        """,
        relay.Call(
            relay.Function(
                [X],
                X,
                None,
                []
            ),
            [relay.const(0)],
            None,
            None
        )
    )

    # TODO(@jmp): re-enable after sequence parsing improvements
    # curried function
    # curried_mult = relay.Var("curried_mult")
    # assert parses_as(
    #     """
    #     let %curried_mult =
    #         fn (%x) {
    #         fn (%y) {
    #             %x * %y
    #         }
    #         };
    #     %curried_mult(0);
    #     %curried_mult(0)(0)
    #     """,
    #     relay.Let(
    #         curried_mult,
    #         relay.Function(
    #             [X],
    #             relay.Function(
    #                 [Y],
    #                 relay.multiply(X, Y),
    #                 None,
    #                 []
    #             ),
    #             None,
    #             []
    #         ),
    #         relay.Let(
    #             _,
    #             relay.Call(curried_mult, [relay.const(0)], None, None),
    #             relay.Call(relay.Call(curried_mult, [relay.const(0)], None, None), [relay.const(0)], None, None)
    #         )
    #     )
    # )

    # op
    assert parses_as(
        "abs(1)",
        relay.Call(relay.op.get("abs"), [relay.const(1)], None, None)
    )


# Types


@if_parser_enabled
def test_incomplete_type():
    assert parses_as(
        "let %_ : _ = (); ()",
        relay.Let(
            _,
            UNIT,
            UNIT
        )
    )


@if_parser_enabled
def test_builtin_types():
    for builtin_type in TYPES:
        parsed = parse_text("let %_ : {} = (); ()".format(builtin_type))
        assert isinstance(parsed, relay.Let)


@nottest
@if_parser_enabled
def test_call_type():
    # Placeholder: call types are not tested yet (disabled via @nottest).
    assert False


@if_parser_enabled
def test_tensor_type():
    assert parses_as(
        "let %_ : Tensor[(), float32] = (); ()",
        relay.Let(
            relay.Var("_", relay.TensorType((), "float32")),
            UNIT,
            UNIT
        )
    )

    assert parses_as(
        "let %_ : Tensor[(1,), float32] = (); ()",
        relay.Let(
            relay.Var("_", relay.TensorType((1,), "float32")),
            UNIT,
            UNIT
        )
    )

    assert parses_as(
        "let %_ : Tensor[(1, 1), float32] = (); ()",
        relay.Let(
            relay.Var("_", relay.TensorType((1, 1), "float32")),
            UNIT,
            UNIT
        )
    )


@if_parser_enabled
def test_function_type():
    assert parses_as(
        """
        let %_: fn () -> int32 = fn () -> int32 { 0 }; ()
        """,
        relay.Let(
            relay.Var("_", relay.FuncType([], int32, [], [])),
            relay.Function([], relay.const(0), int32, []),
            UNIT
        )
    )

    assert parses_as(
        """
        let %_: fn (int32) -> int32 = fn (%x: int32) -> int32 { 0 }; ()
        """,
        relay.Let(
            relay.Var("_", relay.FuncType([int32], int32, [], [])),
            relay.Function([relay.Var("x", int32)], relay.const(0), int32, []),
            UNIT
        )
    )

    assert parses_as(
        """
        let %_: fn (int32, int32) -> int32 = fn (%x: int32, %y: int32) -> int32 { 0 }; ()
        """,
        relay.Let(
            relay.Var("_", relay.FuncType([int32, int32], int32, [], [])),
            relay.Function([relay.Var("x", int32), relay.Var("y", int32)], relay.const(0), int32, []),
            UNIT
        )
    )


@if_parser_enabled
def test_tuple_type():
    assert parses_as(
        """
        let %_: () = (); ()
        """,
        relay.Let(
            relay.Var("_", relay.TupleType([])),
            UNIT,
            UNIT
        )
    )

    assert parses_as(
        """
        let %_: (int32,) = (0,); ()
        """,
        relay.Let(
            relay.Var("_", relay.TupleType([int32])),
            relay.Tuple([relay.const(0)]),
            UNIT
        )
    )

    assert parses_as(
        """
        let %_: (int32, int32) = (0, 1); ()
        """,
        relay.Let(
            relay.Var("_", relay.TupleType([int32, int32])),
            relay.Tuple([relay.const(0), relay.const(1)]),
            UNIT
        )
    )