test_hybrid_script.py 10.7 KB
Newer Older
1
import tvm, inspect, sys, traceback, numpy, nose
2 3 4
from tvm.hybrid import script
from tvm.hybrid.intrin import HYBRID_GLOBALS

5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
@nose.tools.nottest
def run_and_check(func, args, outs, var_dict=None, target='llvm'):
    """Run *func* both as the Python emulation and as a compiled TVM module,
    then check the two agree on every output tensor.

    Parameters
    ----------
    func : callable
        A hybrid-script function. Called with numpy/scalar arguments it
        emulates in Python; called with TVM objects it returns the lowered IR.
    args : list of tvm.tensor.Tensor or tvm.expr.Var
        Arguments used to lower, build, and invoke the compiled module.
    outs : list of tvm.tensor.Tensor
        Subset of *args* whose contents are compared after both runs.
    var_dict : dict, optional
        Mapping from symbolic vars to concrete values used to size tensors.
    target : str
        TVM build target, e.g. 'llvm' or 'cuda'.
    """
    # `var_dict=None` avoids the shared-mutable-default-argument pitfall
    # of the previous `var_dict={}`.
    if var_dict is None:
        var_dict = {}

    def tvm_val_2_py_val(val):
        # Substitute symbolic vars with concrete values, then constant-fold
        # so the result can size a numpy buffer.
        val = tvm.ir_pass.Substitute(val, var_dict)
        val = tvm.ir_pass.Simplify(val)
        assert isinstance(val, (tvm.expr.IntImm, tvm.expr.UIntImm))
        return val.value

    ctx = tvm.context(target, 0)

    emu_args = []   # numpy arrays / python scalars for the emulation
    nd_args = []    # tvm.nd arrays / scalars for the compiled module
    to_check = []   # (tvm output, numpy output) pairs to compare at the end
    for i in args:
        if isinstance(i, tvm.tensor.Tensor):
            shape = [tvm_val_2_py_val(j) for j in i.shape]
            if i in outs:
                # Outputs start zeroed so both sides write from a clean slate.
                emu_args.append(numpy.zeros(shape).astype(i.dtype))
                nd_args.append(tvm.nd.array(emu_args[-1], ctx))
                to_check.append((nd_args[-1], emu_args[-1]))
            else:
                emu_args.append(numpy.random.randn(*shape).astype(i.dtype))
                nd_args.append(tvm.nd.array(emu_args[-1], ctx))
        else:
            assert isinstance(i, tvm.expr.Var)
            emu_args.append(tvm_val_2_py_val(i))
            nd_args.append(emu_args[-1])

    # Python emulation: writes results into the numpy output buffers in place.
    func(*emu_args)

    # Compiled path: calling with TVM objects yields the IR, then lower/build.
    lowered_func = tvm.lower(func(*args), args)
    module = tvm.build(lowered_func, target=target)
    assert module
    module(*nd_args)

    for nd, np_out in to_check:
        tvm.testing.assert_allclose(nd.asnumpy(), np_out, rtol=1e-5, atol=1e-5)
42

43

44 45
@script
def outer_product(n, m, a, b, c):
    """This is a simple outer product.

    Hybrid-script kernel computing c[i, j] = a[i] * b[j] for i in [0, n)
    and j in [0, m).  NOTE: the loop-variable names 'i' and 'j' are
    asserted on by test_outer_product, so they must not be renamed.
    """
    for i in range(n):
        for j in range(m):
            c[i, j] = a[i] * b[j]

#Test global function
#Test bridge between frontend and backend
def test_outer_product():
    """Check the IR structure emitted for outer_product, that it survives a
    lower/build round trip, and that the compiled result matches the
    Python emulation."""
    n = tvm.var('n')
    m = tvm.var('m')
    a = tvm.placeholder((n, ), name='a')
    b = tvm.placeholder((m, ), name='b')
    c = tvm.placeholder((n, m), name='c')
    # Calling the hybrid function with TVM objects returns the lowered IR.
    ir = outer_product(n, m, a, b, c)
    #Check for i in (0, n)
    assert isinstance(ir, tvm.stmt.For)
    assert ir.loop_var.name == 'i'
    assert ir.min.value == 0
    assert ir.extent.name == 'n'
    ibody = ir.body
    assert isinstance(ibody, tvm.stmt.For)
    #Check for j in (0, m)
    assert ibody.loop_var.name == 'j'
    assert ibody.min.value == 0
    assert ibody.extent.name == 'm'
    #Check loop body: c[i, j] = a[i] * b[j]
    jbody = ibody.body
    assert isinstance(jbody, tvm.stmt.Provide)
    assert jbody.func.name == 'c'
    assert len(jbody.args) == 2
    assert jbody.args[0].name == 'i'
    assert jbody.args[1].name == 'j'
    assert isinstance(jbody.value, tvm.expr.Mul)
    mul = jbody.value
    assert isinstance(mul.a, tvm.expr.Call)
    assert mul.a.name == 'a'
    assert mul.b.name == 'b'

    # The IR must also lower and build without error.
    func = tvm.lower(ir, [n, m, a, b, c])
    func = tvm.build(func)

    run_and_check(outer_product, [n, m, a, b, c], [c], {n: 999, m: 1001})

    # Parsing must not leak the hybrid runtime intrinsics into this module's
    # globals nor into the hybrid function's own globals.
    for key, _ in HYBRID_GLOBALS.items():
        assert key not in globals().keys()
        assert key not in outer_product.__globals__.keys()

#Test local function
#Test allocation of local variable
def test_fanout():
    """A 3-tap averaging filter: checks that the local variable 'sigma' is
    realized as a single-element buffer in the IR and that the compiled
    result matches the Python emulation."""
    @script
    def fanout(n, a, b):
        three = 3.0
        for i in range(a.shape[0] - 3):
            sigma = 0.0
            for j in range(3):
                sigma = sigma + a[i + j]
            sigma = sigma / three
            b[i] = sigma

    n = tvm.var('n')
    a = tvm.placeholder((n, ), 'float32', name='a')
    b = tvm.placeholder((n-3, ), 'float32', name='b')
    ir = fanout(n, a, b)

    #Check for i in (0, n-3)
    assert isinstance(ir, tvm.stmt.For)
    assert ir.loop_var.name == 'i'
    assert ir.min.value == 0
    assert tvm.ir_pass.Equal(ir.extent, n - 3)
    #Check loopbody: 'sigma' becomes a one-element realized buffer
    ibody = ir.body
    assert isinstance(ibody, tvm.stmt.AttrStmt)
    abody = ibody.body
    assert isinstance(abody, tvm.stmt.Realize)
    assert abody.bounds[0].min.value == 0
    assert abody.bounds[0].extent.value == 1
    assert abody.func.name == 'sigma'
    #Check i loop body
    rbody = abody.body
    # First statement: sigma[0] = 0.0 (the initialization)
    assert isinstance(rbody.first, tvm.stmt.Provide)
    assert rbody.first.func.name == 'sigma'
    assert len(rbody.first.args) == 1
    assert rbody.first.args[0].value == 0
    #Check fanout loop: for j in (0, 3)
    jloop = rbody.rest.first
    assert jloop.loop_var.name == 'j'
    assert jloop.min.value == 0
    assert jloop.extent.value == 3
    jbody = jloop.body
    # Accumulation: sigma[0] = sigma[0] + a[i + j]
    assert isinstance(jbody, tvm.stmt.Provide)
    assert len(jbody.args) == 1
    assert jbody.args[0].value == 0
    assert jbody.func.name == 'sigma'
    assert isinstance(jbody.value, tvm.expr.Add)
    value = jbody.value
    assert isinstance(value.a, tvm.expr.Call)
    assert value.a.name == 'sigma'
    assert len(value.a.args) == 1
    assert value.a.args[0].value == 0
    assert value.b.name == 'a'
    assert len(value.b.args) == 1
    assert tvm.ir_pass.Equal(value.b.args[0], ir.loop_var + jloop.loop_var)
    # The division by 'three' is lowered to a multiply by ~1/3
    divide= rbody.rest.rest.first
    assert isinstance(divide, tvm.stmt.Provide)
    assert len(divide.args) == 1
    assert divide.args[0].value == 0
    value = divide.value
    assert isinstance(value, tvm.expr.Mul)
    assert value.a.name == 'sigma'
    assert len(value.a.args) == 1
    assert value.a.args[0].value == 0
    assert abs(value.b.value - (1 / 3.0)) < 1e-5
    # Final write back: b[i] = sigma[0]
    write = rbody.rest.rest.rest
    assert isinstance(write, tvm.stmt.Provide)
    assert write.func.name == 'b'
    assert write.value.name == 'sigma'
    assert len(write.value.args) == 1
    assert write.value.args[0].value == 0

    run_and_check(fanout, [n, a, b], [b], {n: 10})
167 168


169 170 171 172 173
@script
def failure():
    """Deliberately malformed hybrid script: assigning to the loop variable
    'i' inside its own loop must be rejected by the parser
    (see test_failure)."""
    for i in range(1, 100):
        i = 0

174 175 176 177 178
def test_failure():
    """The hybrid parser must reject overwriting a loop variable."""
    try:
        tvm.hybrid.parse(failure, [])
    except IOError as err:
        # Python 2 sometimes cannot fetch the function source; the case is
        # skipped there rather than failed.
        assert sys.version_info[0] == 2
        print('[Warning] Case test_failure is skipped by Python2 because "%s"' % str(err))
    except Exception as err:
        assert str(err) == 'You CAN NEVER overwrite a loop variable!'
    else:
        # Previously the test silently passed when parse() raised nothing;
        # the whole point of this case is that parsing must fail.
        raise AssertionError('hybrid.parse should reject overwriting a loop variable')


def test_looptype():
    """Annotated loops (parallel/vectorize/unroll) must map onto the
    corresponding TVM for-types, and the kernel must still compute
    correctly when compiled."""
    @script
    def looptype(a, b, c):
        for i in parallel(8):
            a[i] = i
        for j in vectorize(8):
            b[j] = j
        for k in unroll(8):
            c[k] = k

    a, b, c = [tvm.placeholder((8, ), name=nm, dtype='int32')
               for nm in ('a', 'b', 'c')]
    ir = looptype(a, b, c)
    # The body is a statement sequence: first/rest walks the three loops.
    loops = (ir.first, ir.rest.first, ir.rest.rest)
    expected = (tvm.stmt.For.Parallel,
                tvm.stmt.For.Vectorized,
                tvm.stmt.For.Unrolled)
    for loop, for_type in zip(loops, expected):
        assert loop.for_type == for_type

    run_and_check(looptype, [a, b, c], [a, b, c])
206 207


208 209 210 211 212 213 214 215 216 217 218 219 220 221
def test_if():
    """if/else statements and conditional expressions inside hybrid script."""
    @script
    def if_then_else(a, b):
        for i in range(10):
            if i % 2 == 0:
                a[i] = -1
            else:
                a[i] = 1
        for i in unroll(10):
            b[i] = -1 if i % 2 == 0 else 1

    a, b = [tvm.placeholder((10, ), dtype='int32', name=nm)
            for nm in ('a', 'b')]

    run_and_check(if_then_else, [a, b], [a, b])
223 224 225 226


def test_bind():
    """Bind a hybrid loop to a CUDA thread axis; skipped when no GPU exists."""
    if not tvm.gpu(0).exist:
        print('[Warning] No GPU found! Skip bind test!')
        return

    @script
    def vec_add(a, b, c):
        for tx in bind('threadIdx.x', 1000):
            c[tx] = b[tx] + c[tx]

    a, b, c = [tvm.placeholder((1000, ), dtype='float32', name=nm)
               for nm in ('a', 'b', 'c')]

    run_and_check(vec_add, [a, b, c], [c], target='cuda')
239 240 241 242 243 244 245 246 247 248

def test_math_intrin():
    """Math intrinsics in hybrid script must agree between the compiled
    module and the Python emulation."""
    @script
    def intrin_real(a):
        a[0] = sqrt(a[0])
        a[1] = log(a[1])
        a[2] = exp(a[2])
        a[3] = sigmoid(a[3])
        a[4] = power(a[4], a[5])
        a[5] = tanh(a[5])
        a[6] = min(a[4], a[5])
        a[7] = max(a[5], a[6])

    a8 = tvm.placeholder((8, ), dtype='float32', name='a')
    ir = intrin_real(a8)
    func = tvm.build(tvm.lower(ir, [a8]))
    assert func
    a = numpy.arange(2, 10).astype('float32')
    tvm_a = tvm.ndarray.array(a)
    # Compiled run mutates tvm_a; emulation mutates `a` in place.
    # Both must end up with the same values.
    func(tvm_a)
    intrin_real(a)
    tvm.testing.assert_allclose(a, tvm_a.asnumpy(), rtol=1e-5)

    @script
    def intrin_int(a):
        a[0] = popcount(a[0])

    a1 = tvm.placeholder((1, ), dtype='int32')
    ir = intrin_int(a1)
    func = tvm.build(tvm.lower(ir, [a1]))
    assert func
    a = numpy.array([1234567890]).astype('int32')
    tvm_a = tvm.ndarray.array(a)
    # Emulation first (in-place on `a`), then the compiled run on tvm_a;
    # popcount results must match exactly for integers.
    intrin_int(a)
    func(tvm_a)
    assert tvm_a.asnumpy()[0] == a[0]

276 277 278 279 280 281 282 283
def test_non_zero():
    """Loops with a non-zero lower bound (blur) and loops whose bound
    depends on an outer loop variable (triangle)."""
    @tvm.hybrid.script
    def blur(a, b):
        # 3x3 box blur over the interior of a 32x32 image; note the
        # non-zero loop starts.
        for i in range(2, 32):
            for j in range(2, 32):
                s = 0.0
                for di in range(3):
                    for dj in range(3):
                        s = s + a[i-di, j-dj]
                b[i-2, j-2] = s / 9.0
    try:
        a = tvm.placeholder((32, 32), 'float32', 'a')
        b = tvm.placeholder((30, 30), 'float32', 'b')
        run_and_check(blur, [a, b], [b])
    except IOError as err:
        # Python 2 may fail to fetch the function source; skip, don't fail.
        assert sys.version_info[0] == 2
        print('[Warning] Case test_non_zero is skipped by Python2 because "%s"' % str(err))

    @tvm.hybrid.script
    def triangle(a, b, c):
        # Inner loop's lower bound is the outer loop variable i.
        for i in range(10):
            for j in range(i, 10):
                c[i, j] = a[i] * b[j]

    a = tvm.placeholder((10, ), dtype='float32', name='a')
    b = tvm.placeholder((10, ), dtype='float32', name='b')
    c = tvm.placeholder((10, 10), dtype='float32', name='c')

    run_and_check(triangle, [a, b, c], [c])
305 306 307 308 309 310 311 312 313 314 315 316 317 318

def test_allocate():
    """allocate() creates intermediate buffers inside hybrid script; also
    exercises 'shared' and 'local' memory scopes when a GPU is present."""
    @tvm.hybrid.script
    def blur2d(a, b):
        for i in range(30):
            # Per-row staging buffer of horizontal 3-sums.
            ha = allocate((3, 30), 'float32')
            for j in range(3):
                for k in range(30):
                    ha[j, k] = a[i+j, k] + a[i+j, k+1] + a[i+j, k+2]
            for j in range(30):
                b[i, j] = (ha[0, j] + ha[1, j] + ha[2, j]) / 9.0

    a = tvm.placeholder((32, 32), 'float32', 'a')
    b = tvm.placeholder((30, 30), 'float32', 'b')

    run_and_check(blur2d, [a, b], [b])

    if tvm.gpu().exist:
        @tvm.hybrid.script
        def share_vec_add(a, b, c):
            # Stage the inputs through 'shared' and 'local' scope buffers
            # before combining them.
            shared = allocate((256, ), 'float32', 'shared')
            for i in bind("threadIdx.x", 256):
                shared[i] = a[i]
            local = allocate((256, ), 'float32', 'local')
            for i in bind("threadIdx.x", 256):
                local[i] = b[i]
            for i in bind("threadIdx.x", 256):
                c[i] = shared[i] + local[i]

        a = tvm.placeholder((256, ), dtype='float32', name='a')
        b = tvm.placeholder((256, ), dtype='float32', name='b')
        c = tvm.placeholder((256, ), dtype='float32', name='c')
        run_and_check(share_vec_add, [a, b, c], [c], target='cuda')
    else:
        print('[Warning] No GPU found! Skip shared mem test!')
340

341 342 343 344 345 346 347 348 349

if __name__ == "__main__":
    # Run every case in definition order when executed as a script.
    for case in (test_outer_product,
                 test_fanout,
                 test_failure,
                 test_looptype,
                 test_if,
                 test_bind,
                 test_math_intrin,
                 test_non_zero,
                 test_allocate):
        case()