# pylint: disable=consider-using-enumerate,invalid-name,too-many-function-args
"""
Functions that run on executor for measurement.
These functions are responsible for building tvm module, uploading it to
remote devices, recording the running time costs and checking the correctness of output
"""

import logging
import os
import time
from random import getrandbits
import threading

import numpy as np

from ... import ir_pass, build, build_config, nd, context, TVMError, register_func, \
    target as _target, rpc as _rpc
from ...contrib import nvcc, util, ndk

from ..util import get_const_tuple
from ..env import AutotvmGlobalScope
from ..task.space import InstantiationError

from .measure import MeasureResult, MeasureErrorNo
from .local_executor import LocalExecutor

logger = logging.getLogger('autotvm')

class HashMismatchError(ValueError):
    """Signals that the code hash carried by a submitted config differs from
    the hash computed on the measure side."""

def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
    """Request a remote session from an RPC tracker.

    Parameters
    ----------
    device_key: string
        device key of registered device in tracker
    tracker_addr: Tuple(string, int), optional
        The address of rpc tracker in (host, port) format.
        If it is None, or if either element is empty, the environment variables
        "TVM_TRACKER_HOST" and "TVM_TRACKER_PORT" are used for the missing part.
    priority: int, optional
        The priority of this request, larger is more prior
    timeout: float, optional
        The timeout of this session (units: seconds)

    Returns
    -------
    session: RPCSession
    """
    # connect to the tracker
    if tracker_addr:
        # each element falls back to the environment individually
        host = tracker_addr[0] or os.environ['TVM_TRACKER_HOST']
        port = tracker_addr[1] or int(os.environ['TVM_TRACKER_PORT'])
    else:
        host = os.environ['TVM_TRACKER_HOST']
        port = int(os.environ['TVM_TRACKER_PORT'])

    tracker = _rpc.connect_tracker(host, port)
    remote = tracker.request(device_key, priority=priority,
                             session_timeout=timeout)
    return remote

def check_remote(target, device_key, tracker_addr=None, priority=2, timeout=10):
    """
    Check the availability of a remote device

    Parameters
    ----------
    target: Target
        The wanted compilation target
    device_key: string
        device key of registered device in tracker
    tracker_addr: Tuple(string, int), optional
        The address of rpc tracker in (host, port) format.
        If is none, will use environment variable "TVM_TRACKER_HOST"
        and "TVM_TRACKER_PORT"
    priority: int, optional
        The priority of this request, larger is more prior
    timeout: float, optional
        The timeout of this check (units: seconds).

    Returns
    -------
    available: bool
        True if a session was requested and a context for `target` was created
        within `timeout` seconds. False if the check is still running when the
        timeout expires, or if it failed with an exception.
    """
    # filled by the worker thread only when the whole check went through
    succeeded = []

    def _check():
        try:
            remote = request_remote(device_key, tracker_addr, priority)
            remote.context(str(target))
        except Exception as exc:  # pylint: disable=broad-except
            # A crashed check also ends the thread, so without this flag it
            # would be indistinguishable from a successful one.
            logger.warning("Checking remote device '%s' failed: %s", device_key, exc)
        else:
            succeeded.append(True)

    t = threading.Thread(target=_check)
    t.start()
    t.join(timeout)
    return not t.is_alive() and bool(succeeded)


def create_measure_batch(task, option):
    """Get a standard measure_batch function.

    Parameters
    ----------
    task: tvm.autotvm.task.Task
        The tuning task
    option: dict
        The option for measuring generated code.
        You should use the return value of function :any:`measure_option` for this argument.

    Returns
    -------
    measure_batch: callable
        a callback function to measure a batch of configs.
        It carries two extra attributes: `n_parallel` (taken from `option`) and
        `attach_objects` (the temporary server/tracker pair in 'local' mode, else None).

    Raises
    ------
    RuntimeError
        If the measure function is RPC based and no remote device could be obtained.
    """
    from ..database import filter_inputs

    measure_func = option['measure_func']
    number, repeat = option['number'], option['repeat']
    timeout, n_parallel, do_fork = option['timeout'], option['n_parallel'], option['do_fork']
    build_func = option['build_func']
    check_correctness = option['check_correctness']
    replay_db = option['replay_db']

    executor = LocalExecutor(timeout=timeout, do_fork=do_fork)

    # convert convenient string to function object
    attach_objects = None
    if measure_func == 'local':
        # start temporary rpc tracker and rpc server for the user
        from ...rpc.tracker import Tracker
        from ...rpc.server import Server

        tracker = Tracker('localhost', port=9000, port_end=10000, silent=True)
        device_key = '$local$device$%d' % tracker.port
        server = Server('localhost', port=9000, port_end=10000,
                        key=device_key,
                        use_popen=True, silent=True,
                        tracker_addr=(tracker.host, tracker.port))

        measure_func = rpc(device_key, tracker.host, tracker.port)
        attach_objects = (server, tracker)

    build_kwargs = {}
    if build_func == 'default':
        build_func = default_build_func
    elif build_func == 'ndk':
        build_func = default_build_func
        build_kwargs['use_ndk'] = True

    # check the availability of remote devices
    if hasattr(measure_func, 'rpc_info'):
        rpc_info = measure_func.rpc_info
        if check_remote(task.target, rpc_info['key'], (rpc_info['host'], rpc_info['port'])):
            logger.info("Get devices for measurement successfully!")
        else:
            raise RuntimeError("Cannot get remote devices from the tracker. "
                               "Please check the status of tracker by "
                               "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
                               "and make sure you have free devices on the queue status.")

    # add device info of cuda and opencl target
    if ('cuda' in task.target.keys or 'opencl' in task.target.keys) \
            and hasattr(measure_func, 'rpc_info'):
        rpc_info = measure_func.rpc_info
        add_gpu_target_info(task.target, rpc_info["key"], (rpc_info["host"], rpc_info["port"]),
                            build_kwargs)

    if check_correctness:
        # use llvm cpu to generate a reference input/output
        # this option works for tuning topi, but might not work for your custom op
        with _target.create("llvm"):
            s, arg_bufs = task.instantiate(task.config_space.get(0))
        ref_input = [np.random.uniform(size=get_const_tuple(x.shape)).astype(x.dtype)
                     for x in arg_bufs]
        func = build(s, arg_bufs, "llvm")
        tvm_buf = [nd.array(x) for x in ref_input]
        func(*tvm_buf)
        ref_output = [x.asnumpy() for x in tvm_buf]
    else:
        ref_input = ref_output = None

    def measure_batch(measure_inputs):
        """measure the time cost for a batch of configs in real machines"""
        if replay_db is not None:
            partial_results, measure_inputs = \
                filter_inputs(replay_db, measure_inputs, retry=False)

        # launch measure jobs in parallel
        pack_size = getattr(measure_func, "pack_size", 1)  # measure `pack_size` inputs in one job
        futures = []
        for i in range(0, len(measure_inputs), pack_size):
            input_pack = measure_inputs[i:i + pack_size]
            ret = executor.submit(
                measure_func,
                input_pack,
                build_func,
                build_kwargs,
                number,
                repeat,
                ref_input,
                ref_output)
            # remember the real size of the pack: the last one may be shorter
            futures.append((ret, len(input_pack)))

        # transform results
        results = []
        for future, n_inputs in futures:
            result = future.get()
            if isinstance(result, Exception):
                tstamp = time.time()
                # one error entry per input of this pack keeps results aligned with inputs
                results.extend([MeasureResult((result,), MeasureErrorNo.FLEET_ERROR,
                                              timeout, tstamp)] * n_inputs)
            else:
                results.extend(result)

        if replay_db is not None:
            # fill the slots that were not answered by the replay database
            result_idx = 0
            for i, cached in enumerate(partial_results):
                if cached is None:
                    partial_results[i] = results[result_idx]
                    result_idx += 1
            return partial_results
        return results

    measure_batch.n_parallel = n_parallel
    # attach server and tracker object to avoid them of being garbage-collected
    measure_batch.attach_objects = attach_objects
    return measure_batch


def rpc(key,
        host=None,
        port=None,
        priority=1,
        session_timeout=60,
        pack_size=1):
    """
    Create a standard measure_func which uses RPC Tracker for measurement.
    This measure_func will request a device from the RPC Tracker and
    upload the built binary library to that device for measurement.

    Parameters
    ----------
    key: str
        The registered key of the device in tracker. The tuner will request devices for
        measurement by this key.
    host: str, optional
        The hostname of RPC Tracker. If not set, will use environment variable "TVM_TRACKER_HOST"
    port: int, optional
        The port of RPC Tracker. If not set, will use environment variable "TVM_TRACKER_PORT"
    priority: int, optional
        Priority of this task, used by scheduler in tracker
    session_timeout: int, optional
        Timeout of rpc session
    pack_size: int, optional
        The number of configs measure in one RPC session.
        Usually this can be set to 1. If your device has high overhead to establish a
        rpc connection, set this higher.

    Returns
    -------
    fmeasure: callable
        The measure function. It carries the attributes `pack_size` and
        `rpc_info` (a dict with the keys "key", "host" and "port").
    """
    def fmeasure(input_pack, build_func, build_kwargs, number, repeat, ref_input, ref_output):
        """Do measurement for a list of inputs inside a same RPC session.

        Parameters
        ----------
        input_pack: List of MeasureInput
            The inputs of measurement
        build_func: callable
            Function for building the code. see :any:`default_build_func` for example
        build_kwargs: dict
            Extra arguments for build_func
        number : int, optional
            Number of times to do the measurement for average
        repeat : int, optional
            Number of times to repeat the measurement.
            In total, the generated code will be run (1 + number x repeat) times,
            where the first one is warm up. The returned result contains `repeat` costs,
            each of which is the average of `number` test run.
        ref_input: List of numpy array
            Reference input for correctness check
        ref_output: List of numpy array
            Reference output for correctness check

        Returns
        -------
        results: List of MeasureResult
            The results for input_pack
        """
        # positional arguments forwarded to request_remote
        remote_args = (key, (host, port), priority, session_timeout)

        res = _measure_common(input_pack, build_func, build_kwargs, number, repeat,
                              ref_input, ref_output,
                              remote_args)
        return res

    fmeasure.pack_size = pack_size
    fmeasure.rpc_info = {"key": key, "host": host, "port": port}
    return fmeasure


def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
297
                    ref_input=None, ref_output=None, remote_args=None):
298 299 300
    """Measure the time cost for a pack of inputs.

    (Note: A pack is a list of inputs which will be measured inside a same RPC session)
301 302 303 304 305

    Parameters
    ----------
    input_pack : list of MeasureInput
        The inputs we need to evaluate
306 307 308 309 310 311 312 313 314 315 316 317
    build_func : function takes MeasureInput returns tuple of (time_func, ctx, args)
        The build function used to build each input.
    build_kwargs: Dict
        The extra keyword arguments to build_func
    number : int, optional
        Number of times to do the measurement for average
    repeat : int, optional
        Number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first one is warm up. The returned result contains `repeat` costs,
        each of which is the average of `number` test run.
    ref_input: Array of np.ndarray, optional
318
        Reference input for checking correctness
319
    ref_output: Array of np.ndarray, optional
320
        Reference output for checking correctness
321 322
    remote_args: Tuple, optional
        The arguments to request_remote. If is not None, will use remote rpc devices.
323 324 325

    Returns
    -------
326 327
    res_pack : Array of MeasureResult
        The list of results of measurement.
328 329
    """
    res_pack = []
330 331
    tmp_dir = util.tempdir() if remote_args else None
    assert len(input_pack) == 1, "Only supports input_pack == 1 for now"
332

333 334
    for inp in input_pack:
        tic = time.time()
335 336

        # build function
337
        try:
338
            func, arg_bufs, filename = build_func(inp, tmp_dir, **build_kwargs)
339 340 341 342 343 344 345 346 347 348
        except TVMError as exc:
            tstamp = time.time()
            msg = str(exc)
            if "Stack trace returned" in msg:
                msg = msg[:msg.index("Stack trace returned")]
            if "InstantiationError" in msg:
                try:
                    msg = msg.split('\n')[-2].split(": ")[1]
                except Exception:  # pylint: disable=broad-except
                    pass
349 350 351
                res_pack.append(MeasureResult((InstantiationError(msg),),
                                              MeasureErrorNo.INSTANTIATION_ERROR,
                                              tstamp - tic, tstamp))
352 353 354 355 356 357 358
            else:
                res_pack.append(MeasureResult((RuntimeError(msg),),
                                              MeasureErrorNo.COMPILE_HOST,
                                              tstamp - tic, tstamp))
            continue
        except InstantiationError as e:
            tstamp = time.time()
359
            res_pack.append(MeasureResult((InstantiationError(str(e)),),
360 361 362 363 364 365 366
                                          MeasureErrorNo.INSTANTIATION_ERROR,
                                          tstamp - tic, tstamp))
            continue

        # measure time
        errno = MeasureErrorNo.NO_ERROR
        try:
367 368 369 370 371 372 373 374 375 376 377 378 379 380
            # upload built module
            if remote_args:
                remote = request_remote(*remote_args)
                remote.upload(tmp_dir.relpath(filename))
                func = remote.load_module(filename)
                ctx = remote.context(str(inp.target), 0)
                time_f = func.time_evaluator(
                    func.entry_name, ctx, number=number, repeat=repeat)
            else:
                ctx = context(str(inp.target), 0)
                time_f = func.time_evaluator(
                    func.entry_name, ctx, number=number, repeat=repeat)

            # set input
381
            if ref_input:
382
                args = [nd.array(x, ctx=ctx) for x in ref_input]
383
            else:
384 385
                args = [nd.empty(get_const_tuple(x.shape), dtype=x.dtype, ctx=ctx)
                        for x in arg_bufs]
386

387 388 389 390 391
            costs = time_f(*args).results
            if len(costs) > 2:  # remove largest and smallest value to reduce variance
                costs = list(costs)
                costs.sort()
                costs = tuple(costs[1:-1])
392 393

            # check correctness of output
394 395 396
            if ref_output:
                for expected, real in zip(ref_output, args):
                    if not np.allclose(expected, real.asnumpy(), rtol=1e-4):
397
                        logger.warning("Wrong Answer!")
398 399 400 401 402
                        errno = MeasureErrorNo.WRONG_ANSWER
        except TVMError as exc:
            msg = str(exc)
            if "Stack trace returned" in msg:
                msg = msg[:msg.index("Stack trace returned")]
403 404
            if "CUDA Source" in msg:
                msg = msg[:msg.index("CUDA Source")]
405 406 407 408 409 410
            costs = (RuntimeError(msg),)
            errno = MeasureErrorNo.RUNTIME_DEVICE
        tstamp = time.time()
        res_pack.append(MeasureResult(costs, errno, tstamp - tic, tstamp))
    return res_pack

def default_build_func(inp, tmp_dir=None, **kwargs):
    """Build function module. Exception will be raised when any error occurs

    Parameters
    ----------
    inp: MeasureInput
       The input of this measurement
    tmp_dir: tvm.contrib.util.TempDirectory, optional
       The temporary directory for exporting built binary library.
       If is not None (in RPC mode), the library in this directory will be uploaded to
       remote devices.
    kwargs: Dict, optional
        Other extra arguments. The keys read here are "check_gpu" (keyword arguments
        for :any:`gpu_verify_pass`), "cuda_arch" (passed to :any:`set_cuda_target_arch`)
        and "use_ndk" (export the library with the Android NDK).

    Returns
    -------
    func: Function
        TVM built function. Typically this is the return value of tvm.build.
    args: Array of Buffer or Tensor
        The argument list for the function. Typically this is the second argument of tvm.build.
    filename: str
        The filename of the output build library, or None when `tmp_dir` is not given

    Raises
    ------
    InstantiationError
        If the instantiated config is not valid.
    HashMismatchError
        If the code hash of the config differs from the hash of the instantiated schedule.
    """
    # build function
    with inp.target:
        s, args = inp.task.instantiate(inp.config)

        # check invalidity of template and code hash consistency
        if not inp.config.valid():
            raise InstantiationError(inp.config.errors)
        code_hash = getattr(s, 'code_hash', None)
        if inp.config.code_hash != code_hash:
            raise HashMismatchError('got {0:s}, expected {1:s}'
                                    .format(str(inp.config.code_hash), str(code_hash)))

        opts = {}
        if "check_gpu" in kwargs:  # Add verify pass to filter out invalid configs in advance.
            opts["add_lower_pass"] = [(2, gpu_verify_pass(**kwargs['check_gpu']))]
        if 'cuda_arch' in kwargs:
            set_cuda_target_arch(kwargs['cuda_arch'])

        with build_config(**opts):
            func = build(s, args, target_host=inp.task.target_host)

    # export library to temp directory
    if tmp_dir:
        # a random suffix gives every build its own file name
        if kwargs.get('use_ndk', False):  # for Android NDK
            filename = "tmp_func_%0x.so" % getrandbits(64)
            func.export_library(tmp_dir.relpath(filename), ndk.create_shared)
        else:
            filename = "tmp_func_%0x.tar" % getrandbits(64)
            func.export_library(tmp_dir.relpath(filename))
    else:
        filename = None

    return func, args, filename


def add_gpu_target_info(target, device_key, rpc_tracker_addr, kwargs):
    """Add device info for gpu target.
    The info will be used to check the validity of generated code.

    Parameters
    ----------
    target: Target
        The compilation target; its string form selects the remote context
    device_key: string
        device key of registered device in tracker
    rpc_tracker_addr: Tuple(string, int)
        The address of rpc tracker in (host, port) format
    kwargs: dict
        Updated in place: the key "check_gpu" receives the device limits read from
        the remote context, and the key "cuda_arch" is added for cuda targets.
    """
    remote = request_remote(device_key, rpc_tracker_addr)
    ctx = remote.context(str(target), 0)
    max_dims = ctx.max_thread_dimensions
    kwargs['check_gpu'] = {
        'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
        'max_threads_per_block': ctx.max_threads_per_block,
        'max_thread_x': max_dims[0],
        'max_thread_y': max_dims[1],
        'max_thread_z': max_dims[2],
    }

    if 'cuda' in target.keys:
        # drop the dots from the compute version and prefix it with "sm_"
        kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))


def set_cuda_target_arch(arch):
    """Store `arch` as the nvcc target architecture on the current AutoTVM global scope."""
    scope = AutotvmGlobalScope.current
    scope.cuda_target_arch = arch


@register_func
def tvm_callback_cuda_compile(code):
    """Compile `code` to ptx with nvcc, using the target architecture stored on the
    current AutoTVM global scope."""
    arch = AutotvmGlobalScope.current.cuda_target_arch
    return nvcc.compile_cuda(code, target="ptx", arch=arch)


def gpu_verify_pass(**kwargs):
    """Verify the validity of a gpu kernel.
    This pass will check memory usage and number of threads per block.

    Parameters
    ----------
    kwargs: dict
        The constraints handed unchanged to ``ir_pass.VerifyGPUCode``.

    Returns
    -------
    verify_pass: callable
        A function that takes a statement and returns it unchanged when the check
        succeeds, and raises InstantiationError when it fails.
    """
    def verify_pass(stmt):
        valid = ir_pass.VerifyGPUCode(stmt, kwargs)
        if not valid:
            raise InstantiationError("Skipped because of invalid gpu kernel")
        return stmt
    return verify_pass