# test_debug.py -- tests for the Relay ``debug`` operator (tvm.relay.op.debug).
from tvm.relay import var, const, create_executor
from tvm.relay.op import debug


# Module-level flag that a debug callback can set to True (via a `global`
# statement) to record that it was invoked during evaluation.
_test_debug_hit = False

def test_debug():
    """Check that ``debug`` runs its callback and yields the wrapped value.

    Wraps the variable ``x`` in a ``debug`` node with a Python callback,
    evaluates the program with ``x`` bound to 1, and asserts that the
    callback ran and that the program evaluates to the bound value.
    """
    ex = create_executor()
    x = var('x', shape=(), dtype='int32')
    # Record invocations in a local list (mutated by the closure) instead of
    # a module-level global, so this test shares no state with other tests.
    calls = []

    def did_exec(value):
        calls.append(value)

    prog = debug(x, debug_func=did_exec)
    result = ex.evaluate(prog, {x: const(1, 'int32')})
    assert calls, "debug_func was never invoked"
    assert result.asnumpy() == 1

def test_debug_with_expr():
    """Check that ``debug`` wraps a compound expression transparently.

    Wraps ``x + x * x`` in a ``debug`` node with a Python callback,
    evaluates it with ``x`` bound to 2, and asserts that the callback ran
    and that the result is 2 + 2 * 2 == 6.
    """
    ex = create_executor()
    x = var('x', shape=(), dtype='int32')
    # Local list mutated by the closure: no global flag, and no need to reset
    # shared state before the run.
    calls = []

    def did_exec(value):
        calls.append(value)

    prog = debug(x + x * x, debug_func=did_exec)
    result = ex.evaluate(prog, {x: const(2, 'int32')})
    assert calls, "debug_func was never invoked"
    assert result.asnumpy() == 6