# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
from tvm import relay
from tvm.relay import analysis

def alpha_equal(x, y):
    """
    Wrapper around alpha equality which ensures that
    the hash function respects equality.

    Returns True only if ``analysis.alpha_equal(x, y)`` holds *and*
    ``analysis.structural_hash`` yields the same value for ``x`` and ``y``.
    Because ``and`` short-circuits, the hashes are only computed when the
    alpha-equality check has already succeeded.
    """
    return (analysis.alpha_equal(x, y)
            and analysis.structural_hash(x) == analysis.structural_hash(y))

def test_tensor_type_alpha_equal():
    """TensorType equality takes both the shape and the dtype into account."""
    t1 = relay.TensorType((3, 4), "float32")
    t2 = relay.TensorType((3, 4), "float32")
    t3 = relay.TensorType((3, 4, 5), "float32")

    assert t1 == t2
    assert t1 != t3
    # same shape, different dtype
    assert t1 != relay.TensorType((3, 4), "int32")

    # rank-0 (scalar) tensor types
    t1 = relay.TensorType((), "float32")
    t2 = relay.TensorType((), "float32")
    assert t1 == t2


def test_incomplete_type_alpha_equal():
    """IncompleteTypes compare equal only to the very same object."""
    t1 = relay.IncompleteType(relay.Kind.Shape)
    t2 = relay.IncompleteType(relay.Kind.Type)
    t3 = relay.IncompleteType(relay.Kind.Type)

    # only equal when there is pointer equality
    assert t2 == t2
    assert t1 == t1
    assert t1 != t2  # different kind
    assert t2 != t3  # same kind, but distinct objects


def test_type_param_alpha_equal():
    """TypeVars are equal by identity, or when bound at the same position by a FuncType."""
    t1 = relay.TypeVar("v1", relay.Kind.Type)
    t2 = relay.TypeVar("v2", relay.Kind.Shape)
    t3 = relay.TypeVar("v3", relay.Kind.Type)

    # only pointer equality and eq_map allow equal params
    assert t1 == t1
    assert t2 == t2
    assert t1 != t2  # different kind
    assert t1 != t3  # not in eq_map

    # function types are the only way to put type params
    # in eq map
    ft1 = relay.FuncType(tvm.convert([]), t1, tvm.convert([t1]), tvm.convert([]))
    ft2 = relay.FuncType(tvm.convert([]), t3, tvm.convert([t3]), tvm.convert([]))
    # actually an invalid type because t2 is wrong kind
    ft3 = relay.FuncType(tvm.convert([]), t2, tvm.convert([t2]), tvm.convert([]))

    assert ft1 == ft2
    assert ft1 != ft3  # kinds still do not match


def test_func_type_alpha_equal():
    """FuncTypes compare arg types, return type, type params and relations."""
    t1 = relay.TensorType((1, 2), "float32")
    t2 = relay.TensorType((1, 2, 3), "float32")

    tp1 = relay.TypeVar("v1", relay.Kind.Type)
    tp2 = relay.TypeVar("v2", relay.Kind.Type)
    tp3 = relay.TypeVar("v3", relay.Kind.Shape)
    tp4 = relay.TypeVar("v3", relay.Kind.Shape)

    broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
    identity = tvm.get_env_func("tvm.relay.type_relation.Identity")

    tr1 = relay.TypeRelation(broadcast, tvm.convert([tp1, tp3]), 1, None)
    tr2 = relay.TypeRelation(broadcast, tvm.convert([tp2, tp4]), 1, None)
    tr3 = relay.TypeRelation(identity, tvm.convert([tp1, tp3]), 1, None)

    ft = relay.FuncType(tvm.convert([t1, t2]), tp1,
                        tvm.convert([tp1, tp3]),
                        tvm.convert([tr1]))
    # `ft` with every type param renamed (tp1 -> tp2, tp3 -> tp4), including
    # the return type, so equality must come from the eq_map
    translate_vars = relay.FuncType(tvm.convert([t1, t2]), tp2,
                                    tvm.convert([tp2, tp4]),
                                    tvm.convert([tr2]))
    assert ft == translate_vars

    different_args = relay.FuncType(tvm.convert([t1]), tp1,
                                    tvm.convert([tp1, tp3]),
                                    tvm.convert([tr1]))
    assert ft != different_args

    different_order = relay.FuncType(tvm.convert([t2, t1]), tp1,
                                     tvm.convert([tp1, tp3]),
                                     tvm.convert([tr1]))
    assert ft != different_order

    no_rel = relay.FuncType(tvm.convert([t1, t2]), tp1,
                            tvm.convert([tp1, tp3]),
                            tvm.convert([]))
    assert ft != no_rel

    more_vars = relay.FuncType(tvm.convert([t1, t2]), tp2,
                               tvm.convert([tp1, tp2, tp3]),
                               tvm.convert([tr1]))
    assert ft != more_vars

    all_the_vars = relay.FuncType(tvm.convert([t1, t2]), tp1,
                                  tvm.convert([tp1, tp2, tp3, tp4]),
                                  tvm.convert([tr1, tr2]))
    assert ft != all_the_vars

    different_rel = relay.FuncType(tvm.convert([t1, t2]), tp1,
                                   tvm.convert([tp1, tp3]),
                                   tvm.convert([tr3]))
    assert ft != different_rel

    more_rels = relay.FuncType(tvm.convert([t1, t2]), tp1,
                               tvm.convert([tp1, tp3]),
                               tvm.convert([tr1, tr3]))
    assert ft != more_rels


def test_tuple_type_alpha_equal():
    """TupleTypes are equal when their field types are equal, in the same order."""
    t1 = relay.TensorType((1, 2, 3), "float32")
    t2 = relay.TensorType((1, 2, 3, 4), "float32")

    tp1 = relay.TypeVar("v1", relay.Kind.Type)
    tp2 = relay.TypeVar("v2", relay.Kind.Type)

    tup1 = relay.TupleType(tvm.convert([t1, t2, tp1]))
    tup2 = relay.TupleType(tvm.convert([t1, t2, tp1]))
    tup3 = relay.TupleType(tvm.convert([t2, t1, tp1]))
    tup4 = relay.TupleType(tvm.convert([t1, t2, tp2]))

    # as long as types are alpha-equal and in same order,
    # tuples should be alpha-equal
    assert tup1 == tup2
    assert tup1 != tup3  # fields reordered
    assert tup1 != tup4  # free type var differs


def test_type_relation_alpha_equal():
    """TypeRelations compare the relation func, args (in order), num_inputs and attrs."""
    t1 = relay.TensorType((1, 2), "float32")
    t2 = relay.TensorType((1, 2, 3), "float32")
    t3 = relay.TensorType((1, 2, 3, 4), "float32")

    # functions are compared only by pointer equality so
    # we need to be sure to use the same pointers
    broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
    identity = tvm.get_env_func("tvm.relay.type_relation.Identity")

    attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
    attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))

    tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
    same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
    diff_func = relay.TypeRelation(identity, tvm.convert([t1, t2]), 1, attr1)
    diff_order = relay.TypeRelation(broadcast, tvm.convert([t2, t1]), 1, attr1)
    diff_args = relay.TypeRelation(broadcast, tvm.convert([t2, t3]), 1, attr1)
    diff_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr2)
    # structurally equal attrs held in a distinct object
    same_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1_same)

    bigger = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 2, attr1)
    # identical to `bigger` except for num_inputs, so the last assertion
    # isolates the num_inputs comparison
    diff_num_inputs = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 1, attr1)

    # func, number of args, input count, and order should be the same
    assert tr == same
    assert tr != diff_func
    assert tr != diff_order
    assert tr != diff_args
    assert tr != diff_attr
    assert tr == same_attr
    assert tr != bigger

    assert bigger != diff_num_inputs

def test_type_call_alpha_equal():
    """TypeCalls compare the called global type var and the args, in order."""
    h1 = relay.GlobalTypeVar("h1")
    h2 = relay.GlobalTypeVar("h2")
    t1 = relay.TensorType((1, 2), "float32")
    t2 = relay.TensorType((1, 2, 3), "float32")
    t3 = relay.TensorType((1, 2, 3, 4), "float32")
    t4 = relay.TensorType((), "float32")

    tc = relay.TypeCall(h1, [t1, t2, t3])
    same = relay.TypeCall(h1, [t1, t2, t3])

    different_func = relay.TypeCall(h2, [t1, t2, t3])
    different_arg = relay.TypeCall(h1, [t1, t2, t4])
    fewer_args = relay.TypeCall(h1, [t1, t2])
    more_args = relay.TypeCall(h1, [t1, t2, t3, t4])
    different_order_args = relay.TypeCall(h1, [t3, t2, t1])

    assert tc == same
    assert tc != different_func
    assert tc != different_arg
    assert tc != fewer_args
    assert tc != more_args
    assert tc != different_order_args

def test_constant_alpha_equal():
    """Constants are compared by value, not by object identity."""
    x = relay.const(1)
    y = relay.const(2)

    assert alpha_equal(x, x)
    assert not alpha_equal(x, y)
    # a freshly built constant holding the same value is still equal
    assert alpha_equal(x, relay.const(1))


def test_var_alpha_equal():
    """Vars are equal by identity, or when bound at the same position by a Let."""
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")

    # normally only pointer equality
    assert alpha_equal(v1, v1)
    assert not alpha_equal(v1, v2)

    # let node allows for setting the eq_map
    l1 = relay.Let(v1, relay.const(1), v1)
    l2 = relay.Let(v2, relay.const(1), v2)
    l3 = relay.Let(v1, relay.const(1), v2)

    assert alpha_equal(l1, l2)
    assert not alpha_equal(l1, l3)

    # type annotations
    tt1 = relay.TensorType([], "int32")
    tt2 = relay.TensorType([], "int32")
    tt3 = relay.TensorType([], "int64")
    v3 = relay.Var("v3", tt1)
    v4 = relay.Var("v4", tt2)
    v5 = relay.Var("v5", tt3)

    l4 = relay.Let(v3, relay.const(1), v3)
    l5 = relay.Let(v4, relay.const(1), v4)
    l6 = relay.Let(v5, relay.const(1), v5)

    # same annotations
    assert alpha_equal(l4, l5)
    # different annotations
    assert not alpha_equal(l4, l6)
    # one null annotation
    assert not alpha_equal(l1, l4)

def test_global_var_alpha_equal():
    """A GlobalVar is alpha-equal to itself but not to a distinct GlobalVar."""
    first, second = relay.GlobalVar("v1"), relay.GlobalVar("v2")

    # smoke test: identity is enough for equality, distinct globals never match
    assert alpha_equal(first, first)
    assert not alpha_equal(first, second)


def test_tuple_alpha_equal():
    """Tuples are equal when their fields are equal, in the same order."""
    v0 = relay.Var("v0")
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")

    # unit value is a valid tuple
    assert alpha_equal(relay.Tuple([]), relay.Tuple([]))

    tup = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
    same = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])

    assert alpha_equal(tup, same)

    # use the eq_map
    let_tup = relay.Let(v1, tup, v1)
    let_mapped = relay.Let(v2, relay.Tuple([v0, relay.const(2), relay.const(3),
                                            relay.Tuple([relay.const(4)])]),
                           v2)

    assert alpha_equal(let_tup, let_mapped)

    # Apart from `different_start`, every case below keeps `tup`'s first
    # field (v0) so that it differs from `tup` only in the way its name says.
    more_fields = relay.Tuple([v0, relay.const(2), relay.const(3),
                               relay.Tuple([relay.const(4)]), v2])
    assert not alpha_equal(tup, more_fields)

    fewer_fields = relay.Tuple([v0, relay.const(2), relay.const(3)])
    assert not alpha_equal(tup, fewer_fields)

    different_end = relay.Tuple([v0, relay.const(2), relay.const(3),
                                 relay.Tuple([relay.const(5)])])
    assert not alpha_equal(tup, different_end)

    different_start = relay.Tuple([v2, relay.const(2), relay.const(3),
                                   relay.Tuple([relay.const(4)])])
    assert not alpha_equal(tup, different_start)

    longer_at_end = relay.Tuple([v0, relay.const(2), relay.const(3),
                                 relay.Tuple([relay.const(4), relay.const(5)])])
    assert not alpha_equal(tup, longer_at_end)


def test_tuple_get_item_alpha_equal():
    """TupleGetItem nodes are equal only when both the tuple and the index match."""
    x = relay.Var('x')
    y = relay.Var('y')
    x_at_1 = relay.TupleGetItem(x, 1)

    assert not alpha_equal(x_at_1, relay.TupleGetItem(y, 1))  # different tuple
    assert not alpha_equal(x_at_1, relay.TupleGetItem(x, 2))  # different index
    assert alpha_equal(x_at_1, relay.TupleGetItem(x, 1))      # same tuple and index

def test_function_alpha_equal():
    """Functions compare params (positionally, by type), ret type, body and type params."""
    tt1 = relay.TensorType((1, 2, 3), "float32")
    tt2 = relay.TensorType((4, 5, 6), "int8")
    tt3 = relay.TupleType([tt1, tt2])

    v1 = relay.Var("v1", tt1)
    v2 = relay.Var("v2", tt2)
    v3 = relay.Var("v3", tt3)
    v4 = relay.Var("v4", tt2)

    tp1 = relay.TypeVar("tp1", relay.Kind.Type)
    tp2 = relay.TypeVar("tp2", relay.Kind.Type)
    tp3 = relay.TypeVar("tp3", relay.Kind.Shape)
    tp4 = relay.TypeVar("tp4", relay.Kind.Shape)

    basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)]
    basic_tps = [tp1, tp2]

    func = relay.Function([v1, v2], v1,
                          tt2, basic_tps)
    mapped = relay.Function(basic_args, basic_args[0], tt2, basic_tps)
    assert alpha_equal(func, mapped)

    # Most negative cases below return their first parameter, exactly like
    # `func`, so each one differs from `func` in a single aspect.
    fewer_param = relay.Var("v3", tt1)
    fewer_params = relay.Function([fewer_param], fewer_param, tt2, basic_tps)
    assert not alpha_equal(func, fewer_params)

    more_param_list = [relay.Var("v3", tt1),
                       relay.Var("v4", tt2),
                       relay.Var("v2", tt2)]
    more_params = relay.Function(more_param_list, more_param_list[0], tt2, basic_tps)
    assert not alpha_equal(func, more_params)

    params_unordered = relay.Function([v2, v1], v1,
                                      tt2, basic_tps)
    assert not alpha_equal(func, params_unordered)

    params_mismatch = relay.Function([v1, v3], v1,
                                     tt2, basic_tps)
    assert not alpha_equal(func, params_mismatch)

    ret_type_mismatch = relay.Function(basic_args, basic_args[0], tt1, basic_tps)
    assert not alpha_equal(func, ret_type_mismatch)

    # also mis-typed: v3 is a free var whose type is not the declared tt2
    different_body = relay.Function(basic_args, v3, tt2, basic_tps)
    assert not alpha_equal(func, different_body)

    fewer_type_params = relay.Function(basic_args, basic_args[0], tt2, [tp1])
    assert not alpha_equal(func, fewer_type_params)

    more_type_params = relay.Function(basic_args, basic_args[0], tt2, [tp1, tp2, tp3])
    assert not alpha_equal(func, more_type_params)

    # NOTE(review): swapping two same-kind type params that nothing refers to
    # is presumably alpha-equivalent on its own; this case is told apart from
    # `func` by its free-variable body `v4` -- confirm the intended coverage.
    type_params_unordered = relay.Function(basic_args, v4, tt2, [tp2, tp1])
    assert not alpha_equal(func, type_params_unordered)

    # Shape-kinded type params instead of Type-kinded ones
    different_type_params = relay.Function(basic_args, basic_args[0], tt2, [tp3, tp4])
    assert not alpha_equal(func, different_type_params)

    # a well-typed example (the tuple of both params has type tt3) that also
    # differs in body, ret type, and type params
    tupled_example = relay.Function(basic_args,
                                    relay.Tuple([basic_args[0], basic_args[1]]),
                                    tt3)
    assert not alpha_equal(func, tupled_example)

    # nullable: differs from `func` only in the missing return type
    no_ret_type = relay.Function(basic_args, basic_args[0], None, [tp1, tp2])
    # both null
    assert alpha_equal(no_ret_type, no_ret_type)
    # one null
    assert not alpha_equal(func, no_ret_type)
    assert not alpha_equal(no_ret_type, func)


def test_call_alpha_equal():
    """Calls compare the callee, args (in order), attrs and type args."""
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")

    attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
    attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))

    tt1 = relay.TensorType((1, 2, 3), "float32")
    tt2 = relay.TensorType((), "int8")

    basic_args = [relay.const(1), relay.const(2), v2, relay.Tuple([])]

    # manually writing out args to ensure that args does not rely on
    # pointer equality
    call = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([])],
                      attr1, [tt1])
    same = relay.Call(v1, basic_args, attr1, [tt1])
    assert alpha_equal(call, same)

    different_fn = relay.Call(v2, basic_args, attr1, [tt1])
    assert not alpha_equal(call, different_fn)

    fewer_args = relay.Call(v1, [relay.const(1), relay.const(2), v2], attr1, [tt1])
    assert not alpha_equal(call, fewer_args)

    reordered_args = relay.Call(v1, [relay.const(2), relay.const(1),
                                     relay.Tuple([]), v2], attr1, [tt1])
    assert not alpha_equal(call, reordered_args)

    # same arity as `call` (arity is covered by fewer_args / more_args);
    # only the third argument differs
    different_args = relay.Call(v1, [relay.const(1), relay.const(2), relay.const(3),
                                     relay.Tuple([])],
                                attr1, [tt1])
    assert not alpha_equal(call, different_args)

    more_args = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([]),
                                relay.const(3), relay.const(4)], attr1, [tt1])
    assert not alpha_equal(call, more_args)

    different_attrs = relay.Call(v1, basic_args, attr2, [tt1])
    assert not alpha_equal(call, different_attrs)

    # structurally equal attrs held in a distinct object
    same_attrs = relay.Call(v1, basic_args, attr1_same, [tt1])
    assert alpha_equal(call, same_attrs)

    no_type_args = relay.Call(v1, basic_args, attr1)
    assert not alpha_equal(call, no_type_args)

    more_type_args = relay.Call(v1, basic_args, attr1, [tt1, tt2])
    assert not alpha_equal(call, more_type_args)

    different_type_arg = relay.Call(v1, basic_args, attr1, [tt2])
    assert not alpha_equal(call, different_type_arg)


def test_let_alpha_equal():
    """Lets compare the bound var (positionally), its annotation, the value and the body."""
    tt1 = relay.TensorType((), "float32")
    tt2 = relay.TensorType((), "int8")
    v1 = relay.Var("v1")
    v1_wtype = relay.Var("v1", tt1)
    v2 = relay.Var("v2")
    v3 = relay.Var("v3")

    let = relay.Let(v1, relay.const(2), v1)
    mapped = relay.Let(v2, relay.const(2), v2)
    assert alpha_equal(let, mapped)

    mismatched_var = relay.Let(v2, relay.const(2), v3)
    assert not alpha_equal(let, mismatched_var)

    different_value = relay.Let(v2, relay.const(3), v2)
    assert not alpha_equal(let, different_value)

    # same bound value as `let`; only the body differs
    different_body = relay.Let(v2, relay.const(2), relay.const(12))
    assert not alpha_equal(let, different_body)

    # specified types must match
    let_with_type = relay.Let(v1_wtype, relay.const(2), v1_wtype)
    # a distinct var with the same annotation, so equality has to go
    # through the eq_map instead of pointer equality
    v2_wtype = relay.Var("v2", tt1)
    same_type = relay.Let(v2_wtype, relay.const(2), v2_wtype)
    assert alpha_equal(let_with_type, same_type)
    assert not alpha_equal(let, let_with_type)

    v1_wrong_type = relay.Var("v1", tt2)
    different_type = relay.Let(v1_wrong_type, relay.const(2), v1_wrong_type)
    assert not alpha_equal(let_with_type, different_type)


def test_if_alpha_equal():
    """If nodes compare the condition and both branches."""
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")

    if_sample = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
    same = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
    assert alpha_equal(if_sample, same)

    different_cond = relay.If(v2, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
    assert not alpha_equal(if_sample, different_cond)

    different_true = relay.If(v1, relay.const(2), relay.Tuple([relay.const(2), relay.const(3)]))
    assert not alpha_equal(if_sample, different_true)

    different_false = relay.If(v1, relay.const(1), relay.Tuple([]))
    assert not alpha_equal(if_sample, different_false)


def test_constructor_alpha_equal():
    """Smoke test: prelude constructors compare equal only to themselves."""
    prelude = relay.prelude.Prelude(relay.Module())
    nil, cons = prelude.nil, prelude.cons

    for ctor in (nil, cons):
        assert alpha_equal(ctor, ctor)
    assert not alpha_equal(nil, cons)


def test_match_alpha_equal():
    """Match nodes compare the matched data and the clauses, in order.

    Clauses that differ only in the names of their pattern variables
    compare equal (see `equivalent` below).
    """
    mod = relay.Module()
    p = relay.prelude.Prelude(mod)

    x = relay.Var('x')
    y = relay.Var('y')
    nil_case = relay.Clause(relay.PatternConstructor(p.nil), p.nil())
    cons_case = relay.Clause(relay.PatternConstructor(p.cons,
                                                      [relay.PatternVar(x),
                                                       relay.PatternVar(y)]),
                             p.cons(x, y))

    # cons_case with its pattern variables renamed
    z = relay.Var('z')
    a = relay.Var('a')
    equivalent_cons = relay.Clause(relay.PatternConstructor(p.cons,
                                                            [relay.PatternVar(z),
                                                             relay.PatternVar(a)]),
                                   p.cons(z, a))

    data = p.cons(relay.const(1), p.cons(relay.const(2), p.nil()))

    match = relay.Match(data, [nil_case, cons_case])
    equivalent = relay.Match(data, [nil_case, equivalent_cons])
    empty = relay.Match(data, [])
    no_cons = relay.Match(data, [nil_case])
    no_nil = relay.Match(data, [cons_case])
    different_data = relay.Match(p.nil(), [nil_case, cons_case])
    different_order = relay.Match(data, [cons_case, nil_case])
    different_nil = relay.Match(data, [
        relay.Clause(relay.PatternConstructor(p.nil), p.cons(p.nil(), p.nil())),
        cons_case
    ])
    different_cons = relay.Match(data, [
        nil_case,
        relay.Clause(relay.PatternConstructor(p.cons,
                                              [relay.PatternWildcard(),
                                               relay.PatternWildcard()]),
                     p.nil())
    ])
    another_case = relay.Match(data, [
        nil_case,
        cons_case,
        relay.Clause(relay.PatternWildcard(), p.nil())
    ])
    wrong_constructors = relay.Match(data, [
        relay.Clause(relay.PatternConstructor(p.none), p.nil()),
        relay.Clause(relay.PatternConstructor(p.some, [relay.PatternVar(x)]),
                     p.cons(x, p.nil()))
    ])

    assert alpha_equal(match, match)
    assert alpha_equal(match, equivalent)
    assert not alpha_equal(match, no_cons)
    assert not alpha_equal(match, no_nil)
    assert not alpha_equal(match, empty)
    assert not alpha_equal(match, different_data)
    assert not alpha_equal(match, different_order)
    assert not alpha_equal(match, different_nil)
    assert not alpha_equal(match, different_cons)
    assert not alpha_equal(match, another_case)
    assert not alpha_equal(match, wrong_constructors)


def test_op_alpha_equal():
    """Operators are compared by name only."""
    add_op = relay.op.get("add")
    assert alpha_equal(add_op, relay.op.get("add"))
    assert not alpha_equal(add_op, relay.op.get("take"))


def test_graph_equal():
    """Alpha-equality distinguishes shared subexpressions from duplicated ones."""
    x = relay.var("x")

    def add_of_shared_sum():
        # compute x + x once and feed the same node to both operands
        summed = relay.add(x, x)
        return relay.add(summed, summed)

    z0 = add_of_shared_sum()
    z1 = add_of_shared_sum()
    # same arithmetic, but each operand is its own x + x node
    z3 = relay.add(relay.add(x, x), relay.add(x, x))

    assert alpha_equal(z0, z1)

    # z0 reuses a single y0 node while z3 builds two separate ones, so the
    # dataflow graphs differ and Relay treats them as different programs
    # (the difference is also visible in the text format).
    assert not alpha_equal(z0, z3)

def test_hash_unequal():
    """structural_hash agrees on same-shaped functions and differs when shapes differ."""
    x1 = relay.var("x1", shape=(10, 10), dtype="float32")
    y1 = relay.var("y1", shape=(10, 10), dtype="float32")
    func1 = relay.Function([x1, y1], relay.add(x1, y1))

    # func2 is exactly same structure with same variables shapes and dtypes
    x2 = relay.var("x2", shape=(10, 10), dtype="float32")
    y2 = relay.var("y2", shape=(10, 10), dtype="float32")
    func2 = relay.Function([x2, y2], relay.add(x2, y2))

    assert analysis.structural_hash(func1) == analysis.structural_hash(func2)

    # func3 is same as func1 but with different var shapes
    x3 = relay.var("x3", shape=(20, 10), dtype="float32")
    y3 = relay.var("y3", shape=(20, 10), dtype="float32")
    func3 = relay.Function([x3, y3], relay.add(x3, y3))

    assert analysis.structural_hash(func1) != analysis.structural_hash(func3)

if __name__ == "__main__":
    # Run every test in this module once, in definition order.
    test_tensor_type_alpha_equal()
    test_incomplete_type_alpha_equal()
    test_type_param_alpha_equal()
    test_func_type_alpha_equal()
    test_tuple_type_alpha_equal()
    test_type_relation_alpha_equal()
    test_type_call_alpha_equal()
    test_constant_alpha_equal()
    test_var_alpha_equal()
    test_global_var_alpha_equal()
    test_tuple_alpha_equal()
    test_tuple_get_item_alpha_equal()
    test_function_alpha_equal()
    test_call_alpha_equal()
    test_let_alpha_equal()
    test_if_alpha_equal()
    test_constructor_alpha_equal()
    test_match_alpha_equal()
    test_op_alpha_equal()
    test_graph_equal()
    test_hash_unequal()