test_ext.py 1.23 KB
Newer Older
1 2
import tvm_ext
import tvm
3
import numpy as np
4 5 6 7 8 9 10

def test_bind_add():
    """Check that ``tvm_ext.bind_add`` binds a constant into a Python callable.

    ``bind_add`` is given a two-argument adder and the constant ``1``; the
    test expects the returned callable, applied to ``2``, to yield ``3``
    (i.e. it behaves like the adder with ``1`` bound as one operand).
    """
    def plus(lhs, rhs):
        return lhs + rhs

    bound = tvm_ext.bind_add(plus, 1)
    assert bound(2) == 3

11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
def test_ext_dev():
    """Build an add-one kernel for the ``ext_dev`` target and check its output.

    The kernel computes ``B[i] = A[i] + 1.0`` over a vector of length 10 and
    is built with ``"ext_dev"`` as the target and ``"llvm"`` as the host
    target. The test returns early, checking nothing, when the LLVM backend
    is not enabled in this TVM build.
    """
    length = 10
    A = tvm.placeholder((length,), name='A')
    B = tvm.compute((length,), lambda *idx: A(*idx) + 1.0, name='B')
    sched = tvm.create_schedule(B.op)

    # Guard clause: without the LLVM host backend there is nothing to build.
    if not tvm.module.enabled("llvm"):
        return

    kernel = tvm.build(sched, [A, B], "ext_dev", "llvm")
    dev = tvm.ext_dev(0)

    src = np.random.uniform(size=length).astype(A.dtype)
    a_nd = tvm.nd.array(src, dev)
    b_nd = tvm.nd.array(np.zeros(length, dtype=B.dtype), dev)

    kernel(a_nd, b_nd)  # launch the kernel
    np.testing.assert_allclose(b_nd.asnumpy(), a_nd.asnumpy() + 1)


29 30 31 32 33 34
def test_sym_add():
    """Check that ``tvm_ext.sym_add`` builds a node that keeps its operands.

    The node returned by ``sym_add(a, b)`` is expected to expose its operands
    as ``.a`` and ``.b``. The operands are compared with ``same_as`` (identity
    of the underlying TVM nodes) rather than ``==``. In some TVM versions
    ``==`` on expressions is overloaded to build a symbolic comparison, so
    ``same_as`` states the intended identity check explicitly.
    """
    a = tvm.var('a')
    b = tvm.var('b')
    c = tvm_ext.sym_add(a, b)
    # Separate asserts so a failure identifies which operand was not preserved.
    assert c.a.same_as(a)
    assert c.b.same_as(b)

35 36 37 38 39 40 41 42 43 44 45 46
def test_ext_vec():
    """Exercise the extension ``IntVec`` type from Python and through a callback.

    Creates an ``IntVec`` from ``(1, 2, 3)`` and checks its type and the first
    two elements. It then passes the vector to a Python callback wrapped with
    ``tvm.convert`` and checks that the callback receives an ``IntVec`` whose
    third element is ``3``.
    """
    vec = tvm_ext.ivec_create(1, 2, 3)
    assert isinstance(vec, tvm_ext.IntVec)
    assert vec[0] == 1
    assert vec[1] == 2

    def on_receive(received):
        # Invoked via the converted function below with the vector as argument.
        assert isinstance(received, tvm_ext.IntVec)
        assert received[2] == 3

    converted = tvm.convert(on_receive)
    converted(vec)

47
if __name__ == "__main__":
    # Run every test in this module directly, without a test runner.
    test_ext_dev()
    test_ext_vec()
    test_bind_add()
    test_sym_add()