"""Tests for Relay's alpha-equality pass (``ir_pass.alpha_equal``)."""
import tvm
import numpy as np
from tvm import relay
from tvm.relay import ir_pass

def alpha_equal(x, y):
    """
    Wrapper around alpha equality which also checks that the structural
    hash function respects equality.

    If ``ir_pass.alpha_equal`` reports the two terms as equal but their
    structural hashes differ, an ``AssertionError`` is raised rather than
    returning False. Returning False there would let an
    ``assert not alpha_equal(...)`` call site silently pass and hide the
    hashing bug.

    Parameters
    ----------
    x, y
        The Relay terms to compare.

    Returns
    -------
    bool
        The result of ``ir_pass.alpha_equal(x, y)``.
    """
    equal = ir_pass.alpha_equal(x, y)
    # equal terms must hash equally; unequal terms may still collide
    if equal and ir_pass.structural_hash(x) != ir_pass.structural_hash(y):
        raise AssertionError(
            "alpha-equal terms must have equal structural hashes")
    return equal


def test_tensor_type_alpha_equal():
    """Tensor types with the same shape and dtype compare equal."""
    t1 = relay.TensorType((3, 4), "float32")
    t2 = relay.TensorType((3, 4), "float32")
    t3 = relay.TensorType((3, 4, 5), "float32")
    assert t1 == t2
    assert t1 != t3

    # scalar (empty-shape) tensor types
    t1 = relay.TensorType((), "float32")
    t2 = relay.TensorType((), "float32")
    assert t1 == t2


def test_incomplete_type_alpha_equal():
    """Incomplete types are equal only to the very same object."""
    t1 = relay.IncompleteType(relay.Kind.Shape)
    t2 = relay.IncompleteType(relay.Kind.Type)
    t3 = relay.IncompleteType(relay.Kind.Type)

    # only equal when there is pointer equality
    assert t2 == t2
    assert t1 == t1
    assert t1 != t2
    assert t2 != t3  # same kind, but distinct objects


def test_type_param_alpha_equal():
    """Type vars are equal by identity, or when bound together by a FuncType."""
    t1 = relay.TypeVar("v1", relay.Kind.Type)
    t2 = relay.TypeVar("v2", relay.Kind.Shape)
    t3 = relay.TypeVar("v3", relay.Kind.Type)

    # only pointer equality and eq_map allow equal params
    assert t1 == t1
    assert t2 == t2
    assert t1 != t2 # different kind
    assert t1 != t3 # not in eq_map

    # function types are the only way to put type params
    # in eq map
    ft1 = relay.FuncType(tvm.convert([]), t1, tvm.convert([t1]), tvm.convert([]))
    ft2 = relay.FuncType(tvm.convert([]), t3, tvm.convert([t3]), tvm.convert([]))
    # actually an invalid type because t2 is wrong kind
    ft3 = relay.FuncType(tvm.convert([]), t2, tvm.convert([t2]), tvm.convert([]))

    assert ft1 == ft2
    assert ft1 != ft3 # kinds still do not match


def test_func_type_alpha_equal():
    """Function types are equal up to a consistent renaming of type params."""
    t1 = relay.TensorType((1, 2), "float32")
    t2 = relay.TensorType((1, 2, 3), "float32")

    tp1 = relay.TypeVar("v1", relay.Kind.Type)
    tp2 = relay.TypeVar("v2", relay.Kind.Type)
    tp3 = relay.TypeVar("v3", relay.Kind.Shape)
    # a distinct var of the same kind as tp3, used as its renaming
    tp4 = relay.TypeVar("v4", relay.Kind.Shape)

    broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
    identity = tvm.get_env_func("tvm.relay.type_relation.Identity")

    tr1 = relay.TypeRelation(broadcast, tvm.convert([tp1, tp3]), 1, None)
    tr2 = relay.TypeRelation(broadcast, tvm.convert([tp2, tp4]), 1, None)
    tr3 = relay.TypeRelation(identity, tvm.convert([tp1, tp3]), 1, None)

    ft = relay.FuncType(tvm.convert([t1, t2]), tp1,
                        tvm.convert([tp1, tp3]),
                        tvm.convert([tr1]))
    # ft with tp1 -> tp2 and tp3 -> tp4 applied everywhere, including the
    # return type (leaving tp1 there would make it a free var in this type)
    translate_vars = relay.FuncType(tvm.convert([t1, t2]), tp2,
                                    tvm.convert([tp2, tp4]),
                                    tvm.convert([tr2]))
    assert ft == translate_vars

    different_args = relay.FuncType(tvm.convert([t1]), tp1,
                                    tvm.convert([tp1, tp3]),
                                    tvm.convert([tr1]))
    assert ft != different_args

    different_order = relay.FuncType(tvm.convert([t2, t1]), tp1,
                                     tvm.convert([tp1, tp3]),
                                     tvm.convert([tr1]))
    assert ft != different_order

    no_rel = relay.FuncType(tvm.convert([t1, t2]), tp1,
                            tvm.convert([tp1, tp3]),
                            tvm.convert([]))
    assert ft != no_rel

    more_vars = relay.FuncType(tvm.convert([t1, t2]), tp2,
                               tvm.convert([tp1, tp2, tp3]),
                               tvm.convert([tr1]))
    assert ft != more_vars

    all_the_vars = relay.FuncType(tvm.convert([t1, t2]), tp1,
                                  tvm.convert([tp1, tp2, tp3, tp4]),
                                  tvm.convert([tr1, tr2]))
    assert ft != all_the_vars

    different_rel = relay.FuncType(tvm.convert([t1, t2]), tp1,
                                   tvm.convert([tp1, tp3]),
                                   tvm.convert([tr3]))
    assert ft != different_rel

    more_rels = relay.FuncType(tvm.convert([t1, t2]), tp1,
                               tvm.convert([tp1, tp3]),
                               tvm.convert([tr1, tr3]))
    assert ft != more_rels


def test_tuple_type_alpha_equal():
    """Tuple types are equal when their field types are equal and in order."""
    t1 = relay.TensorType((1, 2, 3), "float32")
    t2 = relay.TensorType((1, 2, 3, 4), "float32")
    tp1 = relay.TypeVar("v1", relay.Kind.Type)
    tp2 = relay.TypeVar("v2", relay.Kind.Type)

    tup1 = relay.TupleType(tvm.convert([t1, t2, tp1]))
    tup2 = relay.TupleType(tvm.convert([t1, t2, tp1]))
    tup3 = relay.TupleType(tvm.convert([t2, t1, tp1]))
    tup4 = relay.TupleType(tvm.convert([t1, t2, tp2]))

    # as long as types are alpha-equal and in same order,
    # tuples should be alpha-equal
    assert tup1 == tup2
    assert tup1 != tup3
    assert tup1 != tup4


def test_type_relation_alpha_equal():
    """Type relations must agree on func, args, arg order, attrs and num_inputs."""
    t1 = relay.TensorType((1, 2), "float32")
    t2 = relay.TensorType((1, 2, 3), "float32")
    t3 = relay.TensorType((1, 2, 3, 4), "float32")

    # functions are compared only by pointer equality so
    # we need to be sure to use the same pointers
    broadcast = tvm.get_env_func("tvm.relay.type_relation.Broadcast")
    identity = tvm.get_env_func("tvm.relay.type_relation.Identity")

    # attrs are compared by value: attr1_same is a distinct node with the same
    # fields as attr1, and relations built with either are expected to be equal
    attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
    attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))

    tr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
    same = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1)
    diff_func = relay.TypeRelation(identity, tvm.convert([t1, t2]), 1, attr1)
    diff_order = relay.TypeRelation(broadcast, tvm.convert([t2, t1]), 1, attr1)
    diff_args = relay.TypeRelation(broadcast, tvm.convert([t2, t3]), 1, attr1)
    diff_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr2)
    same_attr = relay.TypeRelation(broadcast, tvm.convert([t1, t2]), 1, attr1_same)

    bigger = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 2, attr1)
    # same func, args and attrs as `bigger`; only num_inputs differs
    diff_num_inputs = relay.TypeRelation(identity, tvm.convert([t1, t3, t2]), 1, attr1)

    # func, number of args, input count, and order should be the same
    assert tr == same
    assert tr != diff_func
    assert tr != diff_order
    assert tr != diff_args
    assert tr != diff_attr
    assert tr == same_attr
    assert tr != bigger

    assert bigger != diff_num_inputs


def test_constant_alpha_equal():
    """Constants are equal when they hold the same value."""
    x = relay.const(1)
    y = relay.const(2)
    assert alpha_equal(x, x)
    assert not alpha_equal(x, y)
    # a separately constructed constant with the same value
    assert alpha_equal(x, relay.const(1))


def test_var_alpha_equal():
    """Vars are equal by identity, or when bound to each other by Let nodes."""
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")

    # normally only pointer equality
    assert alpha_equal(v1, v1)
    assert not alpha_equal(v1, v2)

    # let node allows for setting the eq_map
    l1 = relay.Let(v1, relay.const(1), v1)
    l2 = relay.Let(v2, relay.const(1), v2)
    l3 = relay.Let(v1, relay.const(1), v2)

    assert alpha_equal(l1, l2)
    assert not alpha_equal(l1, l3)

    # type annotations
    tt1 = relay.TensorType([], "int32")
    tt2 = relay.TensorType([], "int32")
    tt3 = relay.TensorType([], "int64")
    v3 = relay.Var("v3", tt1)
    v4 = relay.Var("v4", tt2)
    v5 = relay.Var("v5", tt3)

    l4 = relay.Let(v3, relay.const(1), v3)
    l5 = relay.Let(v4, relay.const(1), v4)
    l6 = relay.Let(v5, relay.const(1), v5)

    # same annotations
    assert alpha_equal(l4, l5)
    # different annotations
    assert not alpha_equal(l4, l6)
    # one null annotation
    assert not alpha_equal(l1, l4)


def test_global_var_alpha_equal():
    """Smoke test: a global var equals itself and not another global var."""
    gv_a = relay.GlobalVar("v1")
    gv_b = relay.GlobalVar("v2")

    # pointer equality is all that is needed here
    assert alpha_equal(gv_a, gv_a)
    assert not alpha_equal(gv_a, gv_b)


def test_tuple_alpha_equal():
    """Tuples are equal when fields match pairwise, in order, at every depth."""
    v0 = relay.Var("v0")
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")

    # unit value is a valid tuple
    assert alpha_equal(relay.Tuple([]), relay.Tuple([]))

    tup = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])
    same = relay.Tuple([v0, relay.const(2), relay.const(3), relay.Tuple([relay.const(4)])])

    assert alpha_equal(tup, same)

    # use the eq_map
    let_tup = relay.Let(v1, tup, v1)
    let_mapped = relay.Let(v2, relay.Tuple([v0, relay.const(2), relay.const(3),
                                            relay.Tuple([relay.const(4)])]),
                           v2)
    assert alpha_equal(let_tup, let_mapped)

    # The cases below keep tup's first field (v0) so that each one differs
    # from tup only in the way its name describes.
    more_fields = relay.Tuple([v0, relay.const(2), relay.const(3),
                               relay.Tuple([relay.const(4)]), v2])
    assert not alpha_equal(tup, more_fields)

    fewer_fields = relay.Tuple([v0, relay.const(2), relay.const(3)])
    assert not alpha_equal(tup, fewer_fields)

    different_end = relay.Tuple([v0, relay.const(2), relay.const(3),
                                 relay.Tuple([relay.const(5)])])
    assert not alpha_equal(tup, different_end)

    different_start = relay.Tuple([v2, relay.const(2), relay.const(3),
                                   relay.Tuple([relay.const(4)])])
    assert not alpha_equal(tup, different_start)

    longer_at_end = relay.Tuple([v0, relay.const(2), relay.const(3),
                                 relay.Tuple([relay.const(4), relay.const(5)])])
    assert not alpha_equal(tup, longer_at_end)


def test_tuple_get_item_alpha_equal():
    """TupleGetItem nodes must agree on both the tuple and the index."""
    x = relay.Var('x')
    y = relay.Var('y')
    assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1))
    assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
    assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))


def test_function_alpha_equal():
    """Functions are equal up to renaming of their params and type params."""
    tt1 = relay.TensorType((1, 2, 3), "float32")
    tt2 = relay.TensorType((4, 5, 6), "int8")
    tt3 = relay.TupleType([tt1, tt2])

    v1 = relay.Var("v1", tt1)
    v2 = relay.Var("v2", tt2)
    v3 = relay.Var("v3", tt3)
    v4 = relay.Var("v4", tt2)

    tp1 = relay.TypeVar("tp1", relay.Kind.Type)
    tp2 = relay.TypeVar("tp2", relay.Kind.Type)
    tp3 = relay.TypeVar("tp3", relay.Kind.Shape)
    tp4 = relay.TypeVar("tp4", relay.Kind.Shape)

    basic_args = [relay.Var("v3", tt1), relay.Var("v4", tt2)]
    basic_tps = [tp1, tp2]

    func = relay.Function([v1, v2], v1, tt2, basic_tps)
    # basic_args[0] plays the role of v1, so this is func with renamed params
    mapped = relay.Function(basic_args, basic_args[0], tt2, basic_tps)
    assert alpha_equal(func, mapped)

    fewer_params = relay.Function([relay.Var("v4", tt2)], v4, tt2, basic_tps)
    assert not alpha_equal(func, fewer_params)

    more_params = relay.Function([relay.Var("v3", tt1),
                                  relay.Var("v4", tt2),
                                  relay.Var("v2", tt2)], v4, tt2, basic_tps)
    assert not alpha_equal(func, more_params)

    params_unordered = relay.Function([v2, v1], v1, tt2, basic_tps)
    assert not alpha_equal(func, params_unordered)

    params_mismatch = relay.Function([v1, v3], v1, tt2, basic_tps)
    assert not alpha_equal(func, params_mismatch)

    # The cases below keep func's (renamed) body, basic_args[0], so that each
    # differs from `mapped` in a single component.

    # differs only in the declared return type
    ret_type_mismatch = relay.Function(basic_args, basic_args[0], tt1, basic_tps)
    assert not alpha_equal(func, ret_type_mismatch)

    # differs only in the body (v3 is free here)
    different_body = relay.Function(basic_args, v3, tt2, basic_tps)
    assert not alpha_equal(func, different_body)

    fewer_type_params = relay.Function(basic_args, basic_args[0], tt2, [tp1])
    assert not alpha_equal(func, fewer_type_params)

    more_type_params = relay.Function(basic_args, basic_args[0], tt2, [tp1, tp2, tp3])
    assert not alpha_equal(func, more_type_params)

    # NOTE(review): tp1 and tp2 share a kind and are not referenced by func,
    # so swapping them alone is presumably just an alpha-renaming; this case
    # keeps the free v4 body, which is what makes it unequal. Confirm intent.
    type_params_unordered = relay.Function(basic_args, v4, tt2, [tp2, tp1])
    assert not alpha_equal(func, type_params_unordered)

    # tp3/tp4 have Kind.Shape, unlike tp1/tp2
    different_type_params = relay.Function(basic_args, basic_args[0], tt2, [tp3, tp4])
    assert not alpha_equal(func, different_type_params)

    # a well-typed example that also differs in body, ret type, and type params
    tupled_example = relay.Function(basic_args, relay.Tuple([v3, v4]), tt3)
    assert not alpha_equal(func, tupled_example)

    # nullable return type; otherwise identical to `mapped`
    no_ret_type = relay.Function(basic_args, basic_args[0], None, [tp1, tp2])
    # both null
    assert alpha_equal(no_ret_type, no_ret_type)
    # one null
    assert not alpha_equal(func, no_ret_type)
    assert not alpha_equal(no_ret_type, func)


def test_call_alpha_equal():
    """Calls must agree on callee, args, attrs and type args."""
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")

    # attrs are compared by value: attr1_same is a distinct node with the same
    # fields as attr1, and calls built with either are expected to be equal
    attr1 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
    attr1_same = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4))
    attr2 = tvm.make.node("attrs.TestAttrs", name="attr", padding=(3,4,4))

    tt1 = relay.TensorType((1, 2, 3), "float32")
    tt2 = relay.TensorType((), "int8")

    basic_args = [relay.const(1), relay.const(2), v2, relay.Tuple([])]

    # write call's args out afresh (new constant and tuple nodes) so that its
    # equality with `same` cannot rely on pointer equality of the args
    call = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([])],
                      attr1, [tt1])
    same = relay.Call(v1, basic_args, attr1, [tt1])
    assert alpha_equal(call, same)

    different_fn = relay.Call(v2, basic_args, attr1, [tt1])
    assert not alpha_equal(call, different_fn)

    fewer_args = relay.Call(v1, [relay.const(1), relay.const(2), v2], attr1, [tt1])
    assert not alpha_equal(call, fewer_args)

    reordered_args = relay.Call(v1, [relay.const(2), relay.const(1),
                                     relay.Tuple([]), v2], attr1, [tt1])
    assert not alpha_equal(call, reordered_args)

    different_args = relay.Call(v1, [relay.const(1), relay.const(2), relay.const(3)],
                                attr1, [tt1])
    assert not alpha_equal(call, different_args)

    more_args = relay.Call(v1, [relay.const(1), relay.const(2), v2, relay.Tuple([]),
                                relay.const(3), relay.const(4)], attr1, [tt1])
    assert not alpha_equal(call, more_args)

    different_attrs = relay.Call(v1, basic_args, attr2, [tt1])
    assert not alpha_equal(call, different_attrs)

    same_attrs = relay.Call(v1, basic_args, attr1_same, [tt1])
    assert alpha_equal(call, same_attrs)

    no_type_args = relay.Call(v1, basic_args, attr1)
    assert not alpha_equal(call, no_type_args)

    more_type_args = relay.Call(v1, basic_args, attr1, [tt1, tt2])
    assert not alpha_equal(call, more_type_args)

    different_type_arg = relay.Call(v1, basic_args, attr1, [tt2])
    assert not alpha_equal(call, different_type_arg)


def test_let_alpha_equal():
    """Lets are equal up to renaming of the bound var, whose types must match."""
    tt1 = relay.TensorType((), "float32")
    tt2 = relay.TensorType((), "int8")
    v1 = relay.Var("v1")
    v1_wtype = relay.Var("v1", tt1)
    v2 = relay.Var("v2")
    v3 = relay.Var("v3")

    let = relay.Let(v1, relay.const(2), v1)
    mapped = relay.Let(v2, relay.const(2), v2)
    assert alpha_equal(let, mapped)

    mismatched_var = relay.Let(v2, relay.const(2), v3)
    assert not alpha_equal(let, mismatched_var)

    different_value = relay.Let(v2, relay.const(3), v2)
    assert not alpha_equal(let, different_value)

    # keep the bound value (2) so that only the body differs
    different_body = relay.Let(v2, relay.const(2), relay.const(12))
    assert not alpha_equal(let, different_body)

    # specified types must match
    let_with_type = relay.Let(v1_wtype, relay.const(2), v1_wtype)
    same_type = relay.Let(v1_wtype, relay.const(2), v1_wtype)
    assert alpha_equal(let_with_type, same_type)
    assert not alpha_equal(let, let_with_type)

    # a fresh binder with a different annotation (rather than rebinding v2)
    v1_int8 = relay.Var("v1", tt2)
    different_type = relay.Let(v1_int8, relay.const(2), v1_int8)
    assert not alpha_equal(let_with_type, different_type)


def test_if_alpha_equal():
    """If nodes must agree on condition, true branch and false branch."""
    v1 = relay.Var("v1")
    v2 = relay.Var("v2")

    if_sample = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
    same = relay.If(v1, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
    assert alpha_equal(if_sample, same)

    different_cond = relay.If(v2, relay.const(1), relay.Tuple([relay.const(2), relay.const(3)]))
    assert not alpha_equal(if_sample, different_cond)

    different_true = relay.If(v1, relay.const(2), relay.Tuple([relay.const(2), relay.const(3)]))
    assert not alpha_equal(if_sample, different_true)

    different_false = relay.If(v1, relay.const(1), relay.Tuple([]))
    assert not alpha_equal(if_sample, different_false)


def test_op_alpha_equal():
    """Operators compare by name."""
    # only checks names
    op1 = relay.op.get("add")
    op2 = relay.op.get("add")
    assert alpha_equal(op1, op2)

    op3 = relay.op.get("take")
    assert not alpha_equal(op1, op3)


def test_graph_equal():
    """Alpha equality distinguishes shared subexpressions from duplicated ones."""
    x = relay.var("x")

    y0 = relay.add(x, x)
    z0 = relay.add(y0, y0)

    y1 = relay.add(x, x)
    z1 = relay.add(y1, y1)

    z3 = relay.add(relay.add(x, x), relay.add(x, x))

    # same sharing structure, built from separate nodes
    assert alpha_equal(z0, z1)

    # z0 feeds the single node y0 into both operands, while z3 builds two
    # separate add(x, x) nodes. The dataflow graphs differ, and Relay's alpha
    # equality treats them as different programs.
    assert not alpha_equal(z0, z3)


if __name__ == "__main__":
    # run every test exactly once, in definition order
    test_tensor_type_alpha_equal()
    test_incomplete_type_alpha_equal()
    test_type_param_alpha_equal()
    test_func_type_alpha_equal()
    test_tuple_type_alpha_equal()
    test_type_relation_alpha_equal()
    test_constant_alpha_equal()
    test_var_alpha_equal()
    test_global_var_alpha_equal()
    test_tuple_alpha_equal()
    test_tuple_get_item_alpha_equal()
    test_function_alpha_equal()
    test_call_alpha_equal()
    test_let_alpha_equal()
    test_if_alpha_equal()
    test_op_alpha_equal()
    test_graph_equal()