# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm_ext
import tvm
import numpy as np

def test_bind_add():
    """Check that ``tvm_ext.bind_add`` yields a callable with ``1`` pre-bound.

    Calling the bound function with ``2`` must produce ``1 + 2 == 3``.
    """
    def plus(lhs, rhs):
        return lhs + rhs

    bound = tvm_ext.bind_add(plus, 1)
    assert bound(2) == 3

def test_ext_dev():
    """Build ``B = A + 1`` for the ``"ext_dev"`` target and verify it on ``tvm.ext_dev(0)``.

    ``"llvm"`` is passed to ``tvm.build`` as the second target argument
    (presumably the host target; confirm against the TVM version in use).
    The check is skipped when TVM was built without LLVM support.
    """
    n = 10
    A = tvm.placeholder((n,), name='A')
    B = tvm.compute((n,), lambda *i: A(*i) + 1.0, name='B')
    s = tvm.create_schedule(B.op)

    def check_llvm():
        # Nothing to verify if this TVM build cannot generate LLVM code.
        if not tvm.module.enabled("llvm"):
            return
        f = tvm.build(s, [A, B], "ext_dev", "llvm")
        ctx = tvm.ext_dev(0)
        # launch the kernel.
        a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
        b = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
        f(a, b)
        tvm.testing.assert_allclose(b.asnumpy(), a.asnumpy() + 1)

    check_llvm()


def test_sym_add():
    """``tvm_ext.sym_add`` must expose its two operands as fields ``a`` and ``b``."""
    lhs, rhs = tvm.var('a'), tvm.var('b')
    node = tvm_ext.sym_add(lhs, rhs)
    assert node.a == lhs and node.b == rhs


def test_ext_vec():
    """Create a ``tvm_ext.IntVec`` and check it in Python and inside a converted callback."""
    vec = tvm_ext.ivec_create(1, 2, 3)
    assert isinstance(vec, tvm_ext.IntVec)
    for idx, expected in enumerate((1, 2)):
        assert vec[idx] == expected

    def check_third_element(vec_arg):
        # The vector must arrive in the callback still typed as IntVec.
        assert isinstance(vec_arg, tvm_ext.IntVec)
        assert vec_arg[2] == 3

    callback = tvm.convert(check_third_element)
    callback(vec)


def test_extract_ext():
    """Functions extracted through ``TVMExtDeclare`` are looked up by name and callable."""
    funcs = tvm.extract_ext_funcs(tvm_ext._LIB.TVMExtDeclare)
    mul = funcs["mul"]
    assert mul(3, 4) == 12


def test_extern_call():
    """Call the extern symbol ``TVMTestAddOne`` from a compute body and check ``B == A + 1`` on CPU.

    Skipped when TVM was built without LLVM support.
    """
    n = 10
    A = tvm.placeholder((n,), name='A')

    def add_one(*idx):
        return tvm.call_extern("float32", "TVMTestAddOne", A(*idx))

    B = tvm.compute((n,), add_one, name='B')
    sched = tvm.create_schedule(B.op)

    def check_llvm():
        if tvm.module.enabled("llvm"):
            func = tvm.build(sched, [A, B], "llvm")
            ctx = tvm.cpu(0)
            # launch the kernel.
            src = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), ctx)
            dst = tvm.nd.array(np.zeros(n, dtype=B.dtype), ctx)
            func(src, dst)
            tvm.testing.assert_allclose(dst.asnumpy(), src.asnumpy() + 1)

    check_llvm()


def test_nd_subclass():
    """Adding ``NDSubClass`` instances must sum their ``addtional_info`` values.

    ``addtional_info`` keeps the exact spelling used by the ``tvm_ext`` API.
    """
    a = tvm_ext.NDSubClass.create(addtional_info=3)
    b = tvm_ext.NDSubClass.create(addtional_info=5)
    # All sums are evaluated before any assertion runs.
    cases = [(a, 3), (b, 5), (a + b, 8), (a + a, 6), (b + b, 10)]
    for arr, expected in cases:
        assert arr.addtional_info == expected


if __name__ == "__main__":
    # Run every test directly when this file is executed as a script.
    test_nd_subclass()
    test_extern_call()
    test_ext_dev()
    test_ext_vec()
    test_bind_add()
    test_sym_add()
    test_extract_ext()