test_autotvm_common.py 2.06 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
"""Common utilities for testing autotvm"""
import time

import tvm
from tvm import autotvm
from tvm.autotvm import MeasureInput, MeasureResult

@autotvm.template
def matmul(N, L, M, dtype):
    """Tunable matrix-multiply template: C[i, j] = sum_k A[i, k] * B[k, j].

    A has shape (N, L), B has shape (L, M) and C has shape (N, M), all with
    element type ``dtype``. The tuning space is a two-way split of each
    spatial axis of C: "tile_y" for the row axis, "tile_x" for the column axis.

    Returns the schedule together with the argument list ``[A, B, C]``.
    """
    # ---- compute declaration ----
    data_a = tvm.placeholder((N, L), name='A', dtype=dtype)
    data_b = tvm.placeholder((L, M), name='B', dtype=dtype)
    red = tvm.reduce_axis((0, L), name='k')
    # Lambda argument names are kept as i, j on purpose: tvm.compute may use
    # them to name the output axes.
    result = tvm.compute(
        (N, M),
        lambda i, j: tvm.sum(data_a[i, red] * data_b[red, j], axis=red),
        name='C')
    sch = tvm.create_schedule(result.op)

    stage = sch[result]
    row, col = stage.op.axis
    red_axis = stage.op.reduce_axis[0]

    # ---- tuning space ----
    # NOTE(review): knob definition order presumably determines config
    # indices (get_sample_records indexes the space by integer) -- keep stable.
    cfg = autotvm.get_config()
    cfg.define_split("tile_y", row, num_outputs=2)
    cfg.define_split("tile_x", col, num_outputs=2)

    # ---- apply the selected tiling ----
    row_outer, row_inner = cfg["tile_y"].apply(sch, result, row)
    col_outer, col_inner = cfg["tile_x"].apply(sch, result, col)
    # Place the reduction loop between the outer tiles and the inner tiles.
    stage.reorder(row_outer, col_outer, red_axis, row_inner, col_inner)

    return sch, [data_a, data_b, result]

35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
@autotvm.template
def bad_matmul(N, L, M, dtype):
    """Matmul template that misbehaves on targets carrying the 'bad_device' key.

    On such a target the reduction covers only ``L - 1`` elements, so the
    result is wrong. The "tile_y"/"tile_x" splits are declared but never
    applied to the schedule. On every other target this defers to
    :func:`matmul`. Presumably used by tests to exercise failure handling.

    Returns the schedule together with the argument list ``[A, B, C]``.
    """
    target_keys = tvm.target.current_target().keys
    if 'bad_device' not in target_keys:
        return matmul(N, L, M, dtype)

    data_a = tvm.placeholder((N, L), name='A', dtype=dtype)
    data_b = tvm.placeholder((L, M), name='B', dtype=dtype)

    # L - 1 instead of L drops the last product of each dot product;
    # presumably intentional, given the template's name.
    short_red = tvm.reduce_axis((0, L-1), name='k')
    result = tvm.compute(
        (N, M),
        lambda i, j: tvm.sum(data_a[i, short_red] * data_b[short_red, j], axis=short_red),
        name='C')
    sch = tvm.create_schedule(result.op)

    # Knobs are registered but not applied below.
    row, col = sch[result].op.axis
    cfg = autotvm.get_config()
    cfg.define_split("tile_y", row, num_outputs=2)
    cfg.define_split("tile_x", col, num_outputs=2)
    return sch, [data_a, data_b, result]

54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
def get_sample_task(n=128):
    """Create a float32 ``matmul`` tuning task of size n x n x n for LLVM.

    Returns a ``(task, target)`` pair, where ``target`` is the LLVM target
    the task was created for.
    """
    llvm_target = tvm.target.create("llvm")
    shape_args = (n, n, n, 'float32')
    sample = autotvm.task.create(matmul, args=shape_args, target=llvm_target)
    return sample, llvm_target

def get_sample_records(n):
    """Fabricate ``n`` (MeasureInput, MeasureResult) pairs for testing.

    Record ``idx`` pairs the ``idx``-th entry of the default sample task's
    config space with ``MeasureResult((idx + 1,), 0, idx, time.time())``.
    Presumably the fields are costs=(idx + 1,), error_no=0 (success),
    all_cost=idx and timestamp=now -- confirm against MeasureResult.

    Returns a list of ``(MeasureInput, MeasureResult)`` tuples.
    """
    task, target = get_sample_task()
    return [
        (MeasureInput(target, task, task.config_space.get(idx)),
         MeasureResult((idx+1,), 0, idx, time.time()))
        for idx in range(n)
    ]