"""Test code for broadcasting operators."""
import numpy as np
import tvm
import topi

def verify_expand_dims(in_shape, out_shape, axis, num_newaxis):
    """Check topi.expand_dims against a numpy reshape reference.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder.
    out_shape : tuple of int
        Expected shape after inserting ``num_newaxis`` size-1 axes.
    axis : int
        Position at which the new axes are inserted.
    num_newaxis : int
        Number of size-1 axes to insert.
    """
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.expand_dims(A, axis, num_newaxis)
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_broadcast(B)
        foo = tvm.build(s, [A, B], device, name="expand_dims")
        data_npy = np.random.uniform(size=in_shape).astype(A.dtype)
        # expand_dims only adds size-1 axes, so a plain reshape is the reference.
        out_npy = data_npy.reshape(out_shape)
        data_nd = tvm.nd.array(data_npy, ctx)
        # tvm.nd.empty for consistency with the other verify_* helpers.
        out_nd = tvm.nd.empty(out_shape, ctx=ctx, dtype=B.dtype)
        foo(data_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
        check_device(device)


def verify_tranpose(in_shape, axes):
    """Check topi.transpose against numpy's transpose.

    NOTE(review): "tranpose" is a typo for "transpose", but the name is kept
    because it is part of this file's call graph (test_tranpose).

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the input placeholder.
    axes : tuple of int or None
        Axis permutation; None reverses the axes (numpy semantics).
    """
    A = tvm.placeholder(shape=in_shape, name="A")
    B = topi.transpose(A, axes)
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(B)
        foo = tvm.build(s, [A, B], device, name="tranpose")
        # arange makes every element distinct, so any wrong permutation fails.
        data_npy = np.arange(np.prod(in_shape)).reshape(in_shape).astype(A.dtype)
        out_npy = data_npy.transpose(axes)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=B.dtype)
        foo(data_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
        check_device(device)

def verify_reshape(src_shape, dst_shape):
    """Check topi.reshape against numpy's reshape.

    Parameters
    ----------
    src_shape : tuple of int
        Shape of the input placeholder.
    dst_shape : tuple of int
        Target shape; must have the same number of elements as src_shape.
    """
    A = tvm.placeholder(shape=src_shape, name="A")
    B = topi.reshape(A, dst_shape)
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(B)
        foo = tvm.build(s, [A, B], device, name="reshape")
        data_npy = np.random.normal(size=src_shape).astype(A.dtype)
        out_npy = np.reshape(data_npy, newshape=dst_shape)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nd = tvm.nd.empty(dst_shape, ctx=ctx, dtype=B.dtype)
        foo(data_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
        check_device(device)

def verify_squeeze(src_shape, axis):
    """Check topi.squeeze against numpy's squeeze.

    Parameters
    ----------
    src_shape : tuple of int
        Shape of the input placeholder.
    axis : int, tuple of int, or None
        Axis or axes to remove; None removes all size-1 axes (numpy semantics).
    """
    A = tvm.placeholder(shape=src_shape, name="A")
    B = topi.squeeze(A, axis=axis)
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(B)

        foo = tvm.build(s, [A, B], device, name="squeeze")
        data_npy = np.random.normal(size=src_shape).astype(A.dtype)
        out_npy = np.squeeze(data_npy, axis=axis)
        data_nd = tvm.nd.array(data_npy, ctx)
        # A fully-squeezed result is a numpy scalar (shape ()); allocate a
        # one-element buffer for it since the device array needs a real shape.
        if out_npy.shape == ():
            out_nd_shape = (1,)
        else:
            out_nd_shape = out_npy.shape
        out_nd = tvm.nd.empty(out_nd_shape, ctx=ctx, dtype=B.dtype)
        foo(data_nd, out_nd)
        np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
        check_device(device)

def verify_concatenate(shapes, axis):
    """Check topi.concatenate against numpy's concatenate.

    Parameters
    ----------
    shapes : list of tuple of int
        Shapes of the input placeholders; all inputs share the dtype of the
        first placeholder when the random data is generated.
    axis : int
        Axis along which the tensors are concatenated (may be negative).
    """
    tensor_l = []
    for i, shape in enumerate(shapes):
        tensor_l.append(tvm.placeholder(shape, name="A" + str(i)))
    out_tensor = topi.concatenate(a_tuple=tensor_l, axis=axis)
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(out_tensor)

        foo = tvm.build(s, tensor_l + [out_tensor], device, name="concatenate")
        data_npys = [np.random.normal(size=shape).astype(tensor_l[0].dtype) for shape in shapes]
        out_npy = np.concatenate(data_npys, axis=axis)
        data_nds = [tvm.nd.array(data_npy, ctx) for data_npy in data_npys]
        out_nd = tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=out_tensor.dtype)
        # The built function takes all inputs followed by the output buffer.
        foo(*(data_nds + [out_nd]))
        np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
        check_device(device)

def verify_split(src_shape, indices_or_sections, axis):
    """Check topi.split against numpy's split.

    Parameters
    ----------
    src_shape : tuple of int
        Shape of the input placeholder.
    indices_or_sections : int or list of int
        Number of equal sections, or split points (numpy semantics).
    axis : int
        Axis along which to split (may be negative).
    """
    A = tvm.placeholder(shape=src_shape, name="A")
    # topi.split returns a list of output tensors, one per section.
    tensor_l = topi.split(A, indices_or_sections, axis=axis)
    def check_device(device):
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            s = topi.generic.schedule_injective(tensor_l)

        foo = tvm.build(s, [A] + tensor_l, device, name="split")
        data_npy = np.random.normal(size=src_shape).astype(A.dtype)
        out_npys = np.split(data_npy, indices_or_sections, axis=axis)
        data_nd = tvm.nd.array(data_npy, ctx)
        out_nds = [tvm.nd.empty(out_npy.shape, ctx=ctx, dtype=tensor_l[0].dtype) for out_npy in out_npys]
        # The built function takes the input followed by one buffer per output.
        foo(*([data_nd] + out_nds))
        for out_nd, out_npy in zip(out_nds, out_npys):
            np.testing.assert_allclose(out_nd.asnumpy(), out_npy)

    for device in ["llvm", "nvptx", "cuda", "opencl", "metal", "rocm", "vulkan"]:
        check_device(device)

def verify_expand_like(in_shape, out_shape, axis):
    """Check topi.expand_like against a numpy reference built by inserting
    the requested axes and tiling them up to the target extents.

    Parameters
    ----------
    in_shape : tuple of int
        Shape of the tensor to be expanded.
    out_shape : tuple of int
        Shape of the tensor providing the target shape.
    axis : list of int
        Axes (possibly negative) to insert and broadcast along.
    """
    A = tvm.placeholder(shape=in_shape, name="A")
    B = tvm.placeholder(shape=out_shape, name="B")
    C = topi.expand_like(A, B, axis)
    s = tvm.create_schedule([C.op])

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)

        ctx = tvm.context(device, 0)
        f = tvm.build(s, [A, B, C], device, name="expand_like")
        # Named 'data' rather than 'input' to avoid shadowing the builtin.
        data = np.random.uniform(size=in_shape).astype(A.dtype)
        tvm_input = tvm.nd.array(data, ctx)

        odim = len(out_shape)
        # Normalize negative axes and insert them in ascending order.
        real_axis = sorted(x if x >= 0 else x + odim for x in axis)
        for x in real_axis:
            data = np.expand_dims(data, x).astype(A.dtype)
        # Tile each newly inserted size-1 axis up to the target extent.
        for x in real_axis:
            data = np.concatenate([data] * out_shape[x], axis=x).astype(A.dtype)
        assert data.shape == out_shape

        tvm_shape_like = tvm.nd.array(np.zeros(out_shape).astype(B.dtype), ctx)
        out = tvm.nd.array(np.zeros(out_shape).astype(A.dtype), ctx)
        f(tvm_input, tvm_shape_like, out)
        np.testing.assert_allclose(out.asnumpy(), data)

    for device in ["llvm"]:
        check_device(device)


def test_expand_dims():
    """Exercise expand_dims with positive and negative insertion axes."""
    verify_expand_dims((3, 10), (3, 10, 1, 1), 2, 2)
    verify_expand_dims((3, 10), (1, 3, 10), -3, 1)


def test_tranpose():
    """Exercise transpose with explicit permutations and the None default."""
    verify_tranpose((3, 10, 2), (1, 0, 2))
    verify_tranpose((3, 10, 5), (2, 0, 1))
    verify_tranpose((3, 10), None)


def test_reshape():
    """Exercise reshape with rank-reducing, rank-preserving, and flattening cases."""
    verify_reshape((1, 2, 3, 4), (2, 3, 4))
    verify_reshape((4, 2, 3, 4), (2, 4, 12))
    verify_reshape((4, 2, 3, 4), (2, 48))
    verify_reshape((16, ), (2, 2, 2, 2))


def test_squeeze():
    """Exercise squeeze with a single axis, a tuple of axes, None, and a
    fully-squeezed (scalar) result."""
    verify_squeeze((1, 2, 3, 4), 0)
    verify_squeeze((1, 2, 1, 4), None)
    verify_squeeze((1, 1, 1, 4), (1, 2))
    verify_squeeze((1, 1, 1, 1), None)


def test_concatenate():
    """Exercise concatenate on 1-D inputs, a middle axis, a negative axis,
    and many inputs of varying leading extent."""
    verify_concatenate([(2,), (2,), (2,)], 0)
    verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1)
    verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1)
    verify_concatenate([(5, 6, 7, 3),
                        (16, 6, 7, 3),
                        (12, 6, 7, 3),
                        (8, 6, 7, 3),
                        (2, 6, 7, 3)], 0)


def test_split():
    """Exercise split with equal sections, explicit indices, and a negative axis."""
    verify_split((2, 12, 3), 3, 1)
    verify_split((2, 12, 3), [2, 4], 1)
    verify_split((10, 12, 24), [5, 7, 9], -1)

def test_expand_like():
    """Exercise expand_like with single and multiple broadcast axes."""
    cases = [
        ((3,), (2, 3), [0]),
        ((2,), (2, 3), [1]),
        ((3, 4), (3, 5, 4), [1]),
        ((5, 7), (5, 6, 7, 8), [1, 3]),
    ]
    for in_shape, out_shape, axis in cases:
        verify_expand_like(in_shape, out_shape, axis)


# Run every test when executed as a script.
if __name__ == "__main__":
    test_concatenate()
    test_tranpose()
    test_expand_dims()
    test_reshape()
    test_squeeze()
    test_split()
    test_expand_like()