# pylint: disable=invalid-name
"""measure bandwidth and compute peak"""

import logging
import tvm
from . import util
from .. import rpc

def _convert_to_remote(func, remote):
    """ ship a locally built module to the rpc server and load it there

    The module is exported as ``tmp_func.tar`` inside a fresh temporary
    directory, uploaded through ``remote`` and loaded on the remote side.

    Parameters
    ----------
    func: module
        module produced by ``tvm.build``
    remote: tvm.rpc.RPCSession
        session the library is uploaded to

    Returns
    -------
    module: the module returned by ``remote.load_module``
    """
    lib_name = "tmp_func.tar"
    workdir = util.tempdir()
    local_path = workdir.relpath(lib_name)
    func.export_library(local_path)
    remote.upload(local_path)
    return remote.load_module(lib_name)

def measure_bandwidth_sum(total_item, item_per_thread, stride,
                          base_type, bits, lanes,
                          target, target_host, remote, ctx, n_times):
    """ measure memory bandwidth of gpu with a product reduction for one type

    Every output element is the product of ``item_per_thread`` strided loads
    from the input. In pseudo code the kernel is

    for each thread
        for i in 1..item_per_thread:
            y[global_id] = y[global_id] * x[base + i * stride]

    Parameters
    ----------
    total_item: int
        number of scalar elements in the input array
    item_per_thread: int
        number of elements each thread accumulates
    stride: int
        stride in memory access
    base_type: str
        can be "int", "float"
    bits: int
        can be 16, 32
    lanes: int
        lane of the vector type, can be 1, 2, 4, 8, 16
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session
    ctx: TVMcontext
        the context of array
    n_times: int
        number of runs for taking mean

    Returns
    -------
    GBPS: float
         gigabyte per second, computed from ``total_item * bits / 8`` bytes,
         or -1 if a TVMError was raised while building or running
    """
    scalar_type = str(base_type) + str(bits)
    if lanes == 1:
        dtype = scalar_type
    else:
        dtype = scalar_type + "x" + str(lanes)

    # number of (vector) elements in the input, and one output per reduction
    num_vec = total_item // lanes
    num_out = num_vec // item_per_thread

    red_axis = tvm.reduce_axis((0, item_per_thread), name="k")
    data = tvm.placeholder((num_vec,), dtype=dtype, name="x")
    # the reducer is named "sum" but combines by multiplication
    prod_reducer = tvm.comm_reducer(lambda x, y: x*y,
                                    lambda t: tvm.const(1, dtype=t), name="sum")

    def _reduce_row(i):
        # consecutive i inside one stride group read consecutive addresses
        start = i // stride * stride * item_per_thread + i % stride
        return prod_reducer(data[start + red_axis * stride], axis=red_axis)

    out = tvm.compute((num_out,), _reduce_row)

    sch = tvm.create_schedule(out.op)
    outer, inner = sch[out].split(out.op.axis[0], target.max_num_threads)
    sch[out].bind(outer, tvm.thread_axis("blockIdx.x"))
    sch[out].bind(inner, tvm.thread_axis("threadIdx.x"))
    sch[out].unroll(red_axis)

    try:
        func = tvm.build(sch, [data, out], target, target_host=target_host)

        x_arr = tvm.nd.empty((num_vec,), dtype=dtype, ctx=ctx)
        y_arr = tvm.nd.empty((num_out,), dtype=dtype, ctx=ctx)

        func = _convert_to_remote(func, remote)
        time_f = func.time_evaluator(func.entry_name, ctx, number=n_times)
        elapsed = time_f(x_arr, y_arr).mean
    except tvm._ffi.base.TVMError:
        # build error (occur when device does not support half)
        return -1

    return 1.0 * (total_item * bits / 8) / 1e9 / elapsed

def measure_bandwidth_all_types(total_item, item_per_thread, n_times,
                                target, target_host, remote, ctx, verbose=True):
    """ measure memory bandwidth for every measured vector type

    Currently float32 with 1, 2, 4, 8 and 16 lanes is measured. For each type
    two strides are tried, ``target.max_num_threads`` and
    ``total_item // (lanes * item_per_thread)``, and the faster one is kept.

    Parameters
    ----------
    total_item: int
        number of elements in input array
    item_per_thread: int
        number of elements each thread accumulates
    n_times: int
        number of runs for averaging
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session
    ctx: TVMcontext
        the context of array
    verbose: bool
        whether to log each result as soon as it is measured

    Returns
    -------
    result: list
        a list of [type_name, GBPS] pairs such as ["float32x4", gbps];
        GBPS is -1 when every stride failed to build or run
    """
    short_stride = target.max_num_threads

    result = []
    for base_type in ["float"]:
        for bits in [32]:
            type_name = base_type + str(bits)
            for lanes in [1, 2, 4, 8, 16]:
                long_stride = total_item // (lanes * item_per_thread)
                best = -1e9
                # try both access patterns and keep the faster one
                for stride in (short_stride, long_stride):
                    speed = measure_bandwidth_sum(total_item, item_per_thread, stride,
                                                  base_type, bits, lanes, target,
                                                  target_host, remote, ctx, n_times)
                    best = max(best, speed)
                label = "%sx%d" % (type_name, lanes)
                result.append([label, best])
                if verbose:
                    logging.info("\t%-10s %.2f GBPS", label, best)
    return result

def measure_compute_mad(total_item, item_per_thread, base_type, bits, lanes,
                        target, target_host, remote, ctx, n_times):
    """ measure peak compute speed by computing mad for a type

    The IR for measurement is

    for each thread
        for i in 1..item_per_thread // 4 // lanes
            x = mad(x, x, y)
            y = mad(y, y, x)

    For integer types the operands are swapped: x = y * y + x and
    y = x * x + y.

    Parameters
    ----------
    total_item: int
        number of elements in input array
    item_per_thread: int
        number of operations each thread does
    base_type: str
        can be "int", "float"
    bits: int
        can be 16, 32
    lanes: int
       lane of the vector type, can be 1, 2, 4, 8, 16
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session; the built module is always uploaded to it
    ctx: TVMcontext
        the context of array
    n_times: int
        number of runs for taking mean

    Returns
    -------
    GOPS: float
         giga operation per second, counting only the operations the kernel
         actually issues, or -1 if a TVMError was raised while building or
         running
    """

    n = total_item

    # use half as many threads for the widest types
    if bits >= 64 or lanes >= 16:
        n //= 2

    max_threads = target.max_num_threads

    base_type = str(base_type) + str(bits)
    dtype = base_type if lanes == 1 else base_type + "x" + str(lanes)

    if 'float' in base_type:
        mad_func = lambda x, y: (x * x + y)
    else:
        mad_func = lambda x, y: y * y + x

    # each loop iteration issues two mads, i.e. 4 * lanes scalar operations
    n_iter = item_per_thread // 4 // lanes
    n_blocks = n // max_threads

    def extern(ins, outs):
        # pylint: disable=unused-argument
        """construct measurement function by building IR directly"""
        ib = tvm.ir_builder.create()

        bx = tvm.thread_axis("blockIdx.x")
        tx = tvm.thread_axis("threadIdx.x")

        ib.scope_attr(bx, "thread_extent", n_blocks)
        ib.scope_attr(tx, "thread_extent", max_threads)

        idx = bx.var * max_threads + tx.var

        a = ib.allocate(dtype, 1, name='a', scope='local')
        b = ib.allocate(dtype, 1, name='b', scope='local')

        a[0] = outs[0].vload(idx, dtype)
        b[0] = outs[0].vload(idx, dtype)

        for _ in range(n_iter):
            a[0] = mad_func(a[0], b[0])
            b[0] = mad_func(b[0], a[0])

        ib.emit(outs[0].vstore(idx, b[0]))
        return ib.get()

    y = tvm.extern((n,), [], extern, name="y", dtype=dtype)
    s = tvm.create_schedule(y.op)

    try:
        func = tvm.build(s, [y], target, target_host=target_host)
        func = _convert_to_remote(func, remote)
        time_f = func.time_evaluator(func.entry_name, ctx, number=n_times)
        y = tvm.nd.empty((n,), dtype=dtype, ctx=ctx)
        time = time_f(y).mean
    except tvm._ffi.base.TVMError:
        # build error (occur when device does not support half)
        return -1

    # Count the work actually launched: whole blocks of max_threads threads,
    # each running n_iter iterations. This equals n * item_per_thread when
    # the sizes divide evenly, and avoids over-reporting when they do not.
    total_ops = n_blocks * max_threads * n_iter * 4 * lanes
    return 1.0 * total_ops / 1e9 / time

def measure_compute_all_types(total_item, item_per_thread, n_times,
                              target, target_host, remote, ctx, verbose=True):
    """ measure peak flops for all types

    float16/32/64 and int32 are measured with 1, 2, 4, 8 and 16 lanes. For
    each type the per-thread workload is swept over half, one and two times
    ``item_per_thread`` and the best speed is kept.

    Parameters
    ----------
    total_item: int
        number of elements in input array
    item_per_thread: int
        number of operations each thread does
    n_times: int
        number of runs for averaging
    target: :any:`tvm.target.Target`
        the target and option of the compilation.
    target_host : str or :any:`tvm.target.Target`
        host compilation target
    remote: tvm.rpc.RPCSession
        remote rpc session
    ctx: TVMcontext
        the context of array
    verbose: bool
        whether to log each result as soon as it is measured

    Returns
    -------
    result: list
        a list of [type_name, GFLOPS/GIOPS] pairs such as ["float32x4", gops]
    """
    result = []
    for base_type in ["float", "int"]:
        unit = "GFLOPS" if base_type == "float" else "GIOPS"
        for bits in [16, 32, 64]:
            if base_type == 'int' and bits != 32:  # only measure int32
                continue
            type_name = base_type + str(bits)
            for lanes in [1, 2, 4, 8, 16]:
                best = -1e9
                workloads = (item_per_thread // 2, item_per_thread, item_per_thread * 2)
                for per_thread in workloads:
                    speed = measure_compute_mad(total_item, per_thread,
                                                base_type, bits, lanes, target,
                                                target_host, remote, ctx, n_times)
                    best = max(best, speed)
                label = "%sx%d" % (type_name, lanes)
                result.append([label, best])

                if verbose:
                    logging.info("\t%-10s %.2f %s", label, best, unit)

    return result


def measure_peak_all(target, target_host, host, port):
    """measure memory bandwidth and peak compute for gpu devices

    Parameters
    ----------
    target: str or :any:`tvm.target.Target`
        device target; must be an opencl, cuda or metal target
    target_host: str or :any:`tvm.target.Target`
        host compilation target
    host: str
        host address of the rpc server
    port: int
        port of the rpc server

    Returns
    -------
    result: dict
        ``{"bandwidth": [...], "compute": [...]}`` holding the lists returned
        by ``measure_bandwidth_all_types`` and ``measure_compute_all_types``

    Raises
    ------
    RuntimeError
        if the target is not an opencl, cuda or metal target
    """

    target = tvm.target.create(target)
    target_str = str(target)

    # Resolve the device context kind before connecting, so an unsupported
    # target is rejected without opening an rpc session.
    if target_str.startswith("opencl"):
        ctx_kind = "cl"
    elif target_str.startswith("cuda"):
        ctx_kind = "gpu"
    elif target_str.startswith("metal"):
        ctx_kind = "metal"
    else:
        raise RuntimeError("Unsupported target")

    remote = rpc.connect(host, port)
    ctx = getattr(remote, ctx_kind)()

    n_times = 20

    bandwidth_total_item = 1 << 25
    bandwidth_item_per_thread = 32

    compute_total_item = 1 << 21
    compute_item_per_thread = 4096

    logging.info("========== measure memory bandwidth ==========")
    bandwidth = measure_bandwidth_all_types(bandwidth_total_item, bandwidth_item_per_thread,
                                            n_times, target, target_host, remote, ctx)

    logging.info("========== measure peak compute ==========")
    compute = measure_compute_all_types(compute_total_item, compute_item_per_thread,
                                        n_times, target, target_host, remote, ctx)

    return {"bandwidth": bandwidth, "compute": compute}