# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np

import tvm
import topi
import topi.testing
from topi import util

from common import get_all_backend


def test_util():
    """Sanity-check the constant-folding helpers in ``topi.util``.

    Builds a constant int32 expression and verifies that both the scalar
    and tuple extraction helpers recover the Python integer value.
    """
    hundred = tvm.const(100, "int32")
    assert util.get_const_int(hundred) == 100
    const_pair = (hundred, hundred)
    assert util.get_const_tuple(const_pair) == (100, 100)


def test_ewise():
    """Exercise elementwise topi operators against their NumPy references.

    Each operator is lowered with the injective schedule on every available
    backend and the device result is compared to the NumPy ground truth.
    """

    def test_apply(
        func,
        name,
        f_numpy,
        low,
        high,
        shape=(20, 3),
        dtype=tvm.float32,
        check_round=False,
        skip_name_check=False,
    ):
        # Symbolic extents keep the compiled kernel shape-generic.
        rows = tvm.var("m")
        cols = tvm.var("l")
        A = tvm.placeholder((rows, cols), dtype=dtype, name="A")

        B = func(A)
        # Elementwise ops must preserve the (symbolic) input shape.
        assert tuple(B.shape) == tuple(A.shape)
        if not skip_name_check:
            # The intrinsic name embedded in the compute body should match.
            assert B.op.body[0].name == name

        a_np = np.random.uniform(low=low, high=high, size=shape).astype(A.dtype) * 10
        if check_round:
            # avoid round check too close to boundary
            a_np += ((np.fmod(a_np, 1) - 0.5) < 1e-6) * 1e-5
        b_np = f_numpy(a_np)

        def check_device(device):
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                return
            print("Running on target: %s" % device)
            with tvm.target.create(device):
                sched = topi.generic.schedule_injective(B)
            kernel = tvm.build(sched, [A, B], device, name=name)
            a_nd = tvm.nd.array(a_np, ctx)
            b_nd = tvm.nd.array(np.zeros_like(b_np), ctx)
            kernel(a_nd, b_nd)
            tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-5, atol=1e-5)

        for device in get_all_backend():
            check_device(device)

    test_apply(topi.floor, "floor", np.floor, -100, 100)
    test_apply(topi.ceil, "ceil", np.ceil, -100, 100)
    test_apply(topi.sign, "sign", np.sign, -100, 100, skip_name_check=True)
    test_apply(topi.trunc, "trunc", np.trunc, -100, 100)
    test_apply(topi.abs, "fabs", np.abs, -100, 100)
    test_apply(topi.round, "round", np.round, -100, 100, check_round=True)
    test_apply(topi.exp, "exp", np.exp, -1, 1)
    test_apply(topi.tanh, "tanh", np.tanh, -10, 10, shape=(128, 128))
    test_apply(topi.tanh, "tanh", np.tanh, -10, 10, shape=(128, 128), dtype="float64")
    test_apply(topi.sigmoid, "sigmoid", lambda x: 1 / (1 + np.exp(-x)), -1, 1)
    test_apply(topi.log, "log", np.log, 0, 100)
    test_apply(topi.sqrt, "sqrt", np.sqrt, 0, 100)
    test_apply(
        topi.rsqrt,
        "rsqrt",
        lambda x: np.ones_like(x) / np.sqrt(x),
        0,
        100,
        skip_name_check=True,
    )


def test_cast():
    """Check ``topi.cast`` across numeric and bool dtype pairs.

    For every (from_dtype, to_dtype) pair the device result is compared
    against NumPy's ``astype`` on the same random input.
    """

    def verify(from_dtype, to_dtype, low=-100, high=100):
        shape = (5, 4)
        A = tvm.placeholder(shape, dtype=from_dtype, name="A")
        B = topi.cast(A, to_dtype)

        if from_dtype == "bool":
            a_np = np.random.choice([True, False], size=shape)
        else:
            a_np = np.random.uniform(low, high, size=shape).astype(from_dtype)
        if to_dtype == "bool":
            # Guarantee at least one exact zero so the bool cast sees both
            # truthy and falsy values.
            a_np = a_np - a_np[2, 3]
        b_np = a_np.astype(to_dtype)

        for device in get_all_backend():
            ctx = tvm.context(device, 0)
            if not ctx.exist:
                print("Skip because %s is not enabled" % device)
                continue
            print("Running on target: %s" % device)
            with tvm.target.create(device):
                sched = topi.generic.schedule_injective(B)
            kernel = tvm.build(sched, [A, B], device)
            a_nd = tvm.nd.array(a_np, ctx)
            b_nd = tvm.nd.empty(shape=shape, dtype=to_dtype, ctx=ctx)
            kernel(a_nd, b_nd)
            tvm.testing.assert_allclose(b_nd.asnumpy(), b_np)

    verify("int32", "float32")
    verify("int32", "float64")
    verify("int32", "bool")
    verify("float32", "int32")
    verify("float32", "float64")
    verify("float32", "bool")
    verify("bool", "float32")
    verify("bool", "int32")


if __name__ == "__main__":
    # Run every test in this module when invoked as a script.
    test_util()
    test_ewise()
    test_cast()