# test_runtime_packed_func.py
import tvm
import numpy as np

def test_get_global():
    """Register a Python function globally, fetch it back by name and call it.

    The registered function asserts that it receives exactly ``targs`` and
    returns 10, so the call checks both argument passing and the return value.
    """
    targs = (10, 10.0, "hello")

    # register into global function table
    @tvm.register_func
    def my_packed_func(*args):
        assert tuple(args) == targs
        return 10

    # get it out from global function table
    f = tvm.get_global_func("my_packed_func")
    assert isinstance(f, tvm.Function)
    y = f(*targs)
    assert y == 10


def test_get_callback_with_node():
    """Pass a TVM node plus a converted Python callback through a global function.

    ``my_callback_with_node`` asserts that the node it receives equals ``x``
    and forwards it to the callback.  The callback asserts that its argument's
    ``handle`` differs from ``x.handle`` and returns the argument, so the final
    result must still carry the value 10.
    """
    x = tvm.convert(10)

    def test(y):
        # NOTE(review): presumably verifies the node reaches the callback as a
        # new Python wrapper rather than the original object -- confirm intent.
        assert y.handle != x.handle
        return y

    f2 = tvm.convert(test)

    # register into global function table
    @tvm.register_func
    def my_callback_with_node(y, f):
        assert y == x
        return f(y)

    # get it out from global function table
    f = tvm.get_global_func("my_callback_with_node")
    assert isinstance(f, tvm.Function)
    y = f(x, f2)
    assert y.value == 10


def test_return_func():
    """Check that a converted Python function can return another function.

    ``addy(y)`` returns a closure ``add(x)`` producing ``tvm.convert(x + y)``.
    After conversion, ``myf(10)`` yields a callable, and calling it with 11
    must give a node whose ``value`` is 21.
    """
    def addy(y):
        def add(x):
            return tvm.convert(x + y)
        return add

    myf = tvm.convert(addy)
    f = myf(10)
    assert f(11).value == 21


def test_convert():
    """Convert a variadic Python function to a TVM function and invoke it.

    ``myfunc`` asserts that it receives exactly ``targs``.  The converted
    function is called with ``targs`` so that assertion is actually exercised,
    rather than only checking the type of the conversion result.
    """
    # convert a function to tvm function
    targs = (10, 10.0, "hello", 10)

    def myfunc(*args):
        assert tuple(args) == targs

    f = tvm.convert(myfunc)
    assert isinstance(f, tvm.Function)
    # Previously never called, which left myfunc's assertion as dead code.
    f(*targs)


def test_byte_array():
    """Pass a ``bytearray`` into a converted Python function.

    The callee asserts that the value it receives compares equal to the
    ``bytearray`` that was passed in.
    """
    s = "hello"
    a = bytearray(s, encoding="ascii")

    def myfunc(ss):
        assert ss == a

    f = tvm.convert(myfunc)
    f(a)


def test_empty_array():
    """Pass an empty converted tuple into a converted Python function.

    The callee asserts that its argument, turned back into a tuple, is empty.
    """
    def myfunc(ss):
        assert tuple(ss) == ()

    x = tvm.convert(())
    tvm.convert(myfunc)(x)


if __name__ == "__main__":
    # Allow running the tests directly without a test runner.
    test_empty_array()
    test_get_global()
    test_get_callback_with_node()
    test_convert()
    test_return_func()
    test_byte_array()