# test_topi_relu.py
"""Test code for relu activation"""
import os
import numpy as np
import tvm
import topi
from topi.util import get_const_tuple

def verify_relu(m, n, dtype):
    """Check topi.cpp.nn.relu on an (m, n) tensor of the given dtype.

    The kernel output is compared against a NumPy reference on every target
    in the device list below; targets that are not enabled are skipped.

    Parameters
    ----------
    m, n : int
        Shape of the 2-D input placeholder.
    dtype : str
        Element type of the input, e.g. 'float32' or 'int8'.
    """
    A = tvm.placeholder((m, n), name='A', dtype=dtype)
    B = topi.cpp.nn.relu(A)
    assert B.dtype == dtype

    # Sample a symmetric range so both sides of relu are exercised: the
    # default uniform range [0, 1) contains no negatives, and after astype()
    # to an integer dtype it collapses to all zeros. +/-100 fits every dtype
    # used by test_relu, including int8.
    a_np = np.random.uniform(
        low=-100, high=100, size=get_const_tuple(A.shape)).astype(A.dtype)
    b_np = np.maximum(a_np, 0).astype(A.dtype)

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        target = topi.cpp.TEST_create_target(device)
        # llvm uses the generic injective schedule; every other target in the
        # list below uses the schedule from topi.cpp.cuda.
        if device == "llvm":
            s = topi.cpp.generic.schedule_injective(target, [B])
        else:
            s = topi.cpp.cuda.schedule_injective(target, [B])
        ctx = tvm.context(device, 0)
        a = tvm.nd.array(a_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
        foo = tvm.build(s, [A, B], device, name="relu")
        foo(a, b)
        np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

    # "llvm" is included so the CPU branch of check_device is reachable.
    for device in ['llvm', 'cuda', 'opencl', 'metal', 'rocm']:
        check_device(device)


def verify_leaky_relu(m, alpha):
    """Check topi.cpp.nn.leaky_relu on a length-m vector using the llvm target.

    Parameters
    ----------
    m : int
        Length of the 1-D input placeholder.
    alpha : float
        Slope the NumPy reference applies to negative inputs.
    """
    A = tvm.placeholder((m,), name='A')
    B = topi.cpp.nn.leaky_relu(A, alpha)
    device = "llvm"
    target = topi.cpp.TEST_create_target(device)
    s = topi.cpp.generic.schedule_injective(target, [B])

    # Draw from [-1, 1) so the negative branch (scaled by alpha) is actually
    # tested; the default [0, 1) range would never exercise alpha.
    a_np = np.random.uniform(
        low=-1.0, high=1.0, size=get_const_tuple(A.shape)).astype(A.dtype)
    b_np = np.where(a_np > 0, a_np, a_np * alpha)
    ctx = tvm.cpu(0)
    a = tvm.nd.array(a_np, ctx)
    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
    foo = tvm.build(s, [A, B], device, name="leaky_relu")
    foo(a, b)
    np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)

def verify_prelu(x, w):
    """Check topi.cpp.nn.prelu against a NumPy reference using the llvm target.

    Parameters
    ----------
    x : tuple of int
        Shape of the input placeholder X.
    w : tuple of int
        Shape of the slope placeholder W; the NumPy reference broadcasts these
        slopes along axis 1 of X.
    """
    X = tvm.placeholder(x, name='X')
    W = tvm.placeholder(w, name='W')
    x_np = np.random.uniform(
        low=-1.0, high=1.0, size=get_const_tuple(X.shape)).astype(X.dtype)
    w_np = np.random.uniform(
        low=-1.0, high=1.0, size=get_const_tuple(W.shape)).astype(W.dtype)

    def _prelu_numpy(data, slope):
        # Reshape slope to (C, 1, ..., 1) so it broadcasts along axis 1 of
        # data for any input rank, instead of assuming a 4-D input with
        # exactly 3 channels.
        slope = slope.reshape((-1,) + (1,) * (data.ndim - 2))
        return np.where(data < 0, data * slope, data)

    out_np = _prelu_numpy(x_np, w_np)
    B = topi.cpp.nn.prelu(X, W)
    device = "llvm"
    target = topi.cpp.TEST_create_target(device)
    s = topi.cpp.generic.schedule_injective(target, [B])

    ctx = tvm.cpu(0)
    x_tvm = tvm.nd.array(x_np, ctx)
    w_tvm = tvm.nd.array(w_np, ctx)

    # Size the output buffer from B, as the other verify_* helpers do.
    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
    foo = tvm.build(s, [X, W, B], device, name="prelu")
    foo(x_tvm, w_tvm, b)
    np.testing.assert_allclose(b.asnumpy(), out_np, rtol=1e-5)

def test_relu():
    """Run verify_relu on a 10x128 input for each listed dtype."""
    for dtype in ['float32', 'float64', 'int32', 'int16', 'int8', 'int64']:
        verify_relu(10, 128, dtype)

def test_leaky_relu():
    """Exercise leaky ReLU on a length-100 vector with slope 0.1."""
    verify_leaky_relu(m=100, alpha=0.1)

def test_prelu():
    """Exercise PReLU on a (1, 3, 2, 2) input with a slope tensor of shape (3,)."""
    data_shape, slope_shape = (1, 3, 2, 2), (3,)
    verify_prelu(data_shape, slope_shape)

if __name__ == "__main__":
    # Allow running the tests directly without a test runner.
    test_relu()
    test_leaky_relu()
    test_prelu()