test_topi_softmax.py 3.29 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17 18 19 20 21
"""Test code for softmax"""
import os
import numpy as np
import tvm
import topi
22
import topi.testing
23
import logging
24 25
from topi.util import get_const_tuple

26 27
from common import get_all_backend

28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
def check_device(A, B, a_np, b_np, device, name):
    """Build and run the scheduled computation of ``B`` on one target and check it.

    The schedule comes from ``topi.generic.schedule_softmax(B)`` inside the
    target's context. The kernel is compiled with ``tvm.build`` under the
    given ``name``, then run on ``a_np``. Its output is compared with ``b_np``
    using a relative tolerance of 1e-5. If the device context does not exist,
    a skip message is printed and nothing is built.

    Parameters
    ----------
    A : tvm placeholder
        Input tensor of the computation.
    B : tvm tensor
        Output tensor. Both softmax and log_softmax callers pass it; this
        assumes ``schedule_softmax`` handles both ops (TODO confirm).
    a_np : numpy.ndarray
        Input data fed to the built function.
    b_np : numpy.ndarray
        Expected output (reference result).
    device : str
        Target name, used both for ``tvm.context`` and ``tvm.build``.
    name : str
        Name given to the built function, e.g. ``"softmax"`` or ``"log_softmax"``.
    """
    ctx = tvm.context(device, 0)
    if not ctx.exist:
        print("Skip because %s is not enabled" % device)
        return
    print("Running on target: %s" % device)
    with tvm.target.create(device):
        s = topi.generic.schedule_softmax(B)

    a = tvm.nd.array(a_np, ctx)
    # Output buffer is pre-filled with zeros; the kernel writes into it.
    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
    # Use the caller-supplied name; it was previously hard-coded to "softmax".
    f = tvm.build(s, [A, B], device, name=name)
    f(a, b)
    tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

43 44
def verify_softmax(m, n, dtype="float32"):
    """Check ``topi.nn.softmax`` on an ``(m, n)`` input against a NumPy reference.

    First lowers ``B`` with a default schedule to confirm lowering succeeds.
    It then compares device results with ``topi.testing.softmax_python`` on
    random uniform data, for each target in the hard-coded list. Targets that
    are not enabled are skipped by ``check_device``.

    Parameters
    ----------
    m, n : int
        Shape of the 2-D input placeholder.
    dtype : str
        Element type of the input, e.g. ``"float32"`` or ``"float64"``.
    """
    A = tvm.placeholder((m, n), dtype=dtype, name='A')
    B = topi.nn.softmax(A)
    # confirm lower works
    s = tvm.create_schedule([B.op])
    tvm.lower(s, [A, B], simple_mode=True)

    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    b_np = topi.testing.softmax_python(a_np)

    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
        check_device(A, B, a_np, b_np, device, "softmax")

def verify_softmax_4d(shape, dtype="float32"):
    """Check ``topi.nn.softmax(A, axis=1)`` on a 4-D input against a NumPy reference.

    The reference is built in three steps:
    1. Move axis 1 to the last position.
    2. Flatten the other three axes into rows.
    3. Apply ``topi.testing.softmax_python`` to the resulting 2-D array.
    It then restores the original axis order. This presumes ``softmax_python``
    normalizes each row of a 2-D array (confirm). Any size of the leading
    dimension is supported.

    Parameters
    ----------
    shape : tuple of 4 ints
        Shape of the input placeholder; softmax is taken over ``shape[1]``.
    dtype : str
        Element type of the input.
    """
    A = tvm.placeholder(shape, dtype=dtype, name='A')
    B = topi.nn.softmax(A, axis=1)

    n, c, h, w = shape
    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    # Put axis 1 last and flatten the rest, so each row is one softmax slice.
    # Including the leading dimension in the row count makes this work for n > 1.
    rows = a_np.transpose(0, 2, 3, 1).reshape(n * h * w, c)
    b_np = topi.testing.softmax_python(rows)
    # Undo the flatten and transpose to return to the layout of `shape`.
    b_np = b_np.reshape(n, h, w, c).transpose(0, 3, 1, 2)

    for device in ['cuda', 'opencl', 'metal', 'rocm', 'vulkan', 'nvptx']:
        check_device(A, B, a_np, b_np, device, "softmax")
67 68 69

def test_softmax():
    """Run the 2-D softmax check for several shapes and dtypes, plus one 4-D case."""
    verify_softmax(32, 10)
    verify_softmax(3, 4)
    verify_softmax(32, 10, "float64")
    verify_softmax_4d((1, 16, 256, 256))
73

74 75
def verify_log_softmax(m, n, dtype="float32"):
    """Check ``topi.nn.log_softmax`` on an ``(m, n)`` input against a NumPy reference.

    First lowers ``B`` with a default schedule to confirm lowering succeeds.
    It then compares device results with ``topi.testing.log_softmax_python``
    on random uniform data, for every backend returned by
    ``get_all_backend()``.

    Parameters
    ----------
    m, n : int
        Shape of the 2-D input placeholder.
    dtype : str
        Element type of the input, e.g. ``"float32"`` or ``"float64"``.
    """
    A = tvm.placeholder((m, n), dtype=dtype, name='A')
    B = topi.nn.log_softmax(A)
    # confirm lower works
    s = tvm.create_schedule([B.op])
    tvm.lower(s, [A, B], simple_mode=True)

    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    b_np = topi.testing.log_softmax_python(a_np)

    for device in get_all_backend():
        check_device(A, B, a_np, b_np, device, "log_softmax")
85

86

87 88 89
def test_log_softmax():
    """Run the 2-D log_softmax check for several shapes and dtypes."""
    verify_log_softmax(32, 10)
    verify_log_softmax(3, 4)
    verify_log_softmax(32, 10, "float64")
91

92
if __name__ == "__main__":
    # Verbose logging when run directly as a script rather than under pytest.
    logging.basicConfig(level=logging.DEBUG)
    test_softmax()
    test_log_softmax()