measure_methods.py 24.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17
# pylint: disable=invalid-name,too-many-function-args,too-many-nested-blocks
18 19
"""
Functions that run on executor for measurement.
20 21 22

These functions are responsible for building the tvm module, uploading it to
remote devices, recording the running time costs, and checking the correctness of the output.
23 24 25
"""

import logging
26
import shutil
27
import os
28
import threading
29 30
import time
from random import getrandbits
31 32
from collections import namedtuple
import tempfile
33 34 35

import numpy as np

36 37
from ... import ir_pass, build, build_config, nd, TVMError, register_func, \
    rpc as _rpc, target as _target
38
from ...contrib import nvcc, ndk, tar
39 40 41 42 43

from ..util import get_const_tuple
from ..env import AutotvmGlobalScope
from ..task.space import InstantiationError

44
from .measure import MeasureResult, MeasureErrorNo, Builder, Runner
45 46
from .local_executor import LocalExecutor

47
logger = logging.getLogger('autotvm')
48

49 50 51
class BuildResult(namedtuple("BuildResult", "filename arg_info error time_cost")):
    """
    Record describing the outcome of building one measure input.

    Parameters
    ----------
    filename : str
        The filename of generated library
    arg_info : Tuple
        The shape and dtype information of tvm tensor arguments
    error : Exception
        The error that happened during compilation, or None on success.
    time_cost : float
        The time cost of building
    """
64

65 66
class LocalBuilder(Builder):
    """Run compilation on local machine

    Parameters
    ----------
    timeout: float
        The timeout of a compilation
    n_parallel: int
        The number of tasks run in parallel. "None" will use all cpu cores
    build_func: callable or str
        If is 'default', use default build function
        If is 'ndk', use function for android ndk
        If is callable, use it as custom build function. The callable must
        carry an ``output_format`` attribute.

    Raises
    ------
    ValueError
        If `build_func` is a string other than 'default' or 'ndk'.
    """
    def __init__(self, timeout=10, n_parallel=None, build_func='default'):
        super(LocalBuilder, self).__init__(timeout, n_parallel)

        if isinstance(build_func, str):
            if build_func == 'default':
                build_func = tar.tar
            elif build_func == 'ndk':
                build_func = ndk.create_shared
            else:
                raise ValueError("Invalid build_func: " + build_func)
        self.build_func = _wrap_build_func(build_func)
        self.executor = LocalExecutor(timeout=timeout)
        self.tmp_dir = tempfile.mkdtemp()

    def build(self, measure_inputs):
        """Build every measure input, `n_parallel` at a time.

        Returns a list with one entry per input, in input order: the
        BuildResult on success, or a MeasureResult describing the failure.
        """
        results = []

        # start every batch from a fresh directory so libraries from the
        # previous call do not pile up
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        self.tmp_dir = tempfile.mkdtemp()

        for i in range(0, len(measure_inputs), self.n_parallel):
            futures = [self.executor.submit(self.build_func,
                                            inp,
                                            self.tmp_dir,
                                            **self.build_kwargs)
                       for inp in measure_inputs[i:i + self.n_parallel]]

            for future in futures:
                results.append(self._classify(future.get()))

        return results

    def _classify(self, res):
        """Map what the executor returned onto a BuildResult or an error MeasureResult."""
        if isinstance(res, Exception):
            # timeout or fleet error, return MeasureResult directly
            return MeasureResult((res,), MeasureErrorNo.BUILD_TIMEOUT,
                                 self.timeout, time.time())

        if res.error is None:
            # successful build: hand the BuildResult through unchanged
            return res

        if isinstance(res.error, InstantiationError):
            return MeasureResult((res.error,),
                                 MeasureErrorNo.INSTANTIATION_ERROR,
                                 res.time_cost, time.time())

        msg = str(res.error)
        if "InstantiationError" in msg:
            # the error arrived as another exception type whose text mentions
            # InstantiationError; try to extract the message part of the
            # second-to-last line
            try:
                msg = msg.split('\n')[-2].split(": ")[1]
            except IndexError:
                pass  # unexpected layout: keep the full text
            return MeasureResult((InstantiationError(msg),),
                                 MeasureErrorNo.INSTANTIATION_ERROR,
                                 res.time_cost, time.time())

        # tvm error
        return MeasureResult((res.error,),
                             MeasureErrorNo.COMPILE_HOST,
                             res.time_cost, time.time())


class RPCRunner(Runner):
    """Run generated code on remote devices.
    This function will ask a RPC Tracker to get device for measurement.

    Parameters
    ----------
    key: str
        The key of the device registered in the tracker
    host: str
        The host address of RPC Tracker
    port: int
        The port of RPC Tracker
    priority: int, optional
        The priority forwarded with every device request
    timeout: float
        The timeout handed to the base runner; also used as the timeout of
        each requested RPC session
    n_parallel: int
        The number of tasks run in parallel. "None" will use all cpu cores
    number: int
        The number of times to run the generated code for taking average.
        We call these runs as one `repeat` of measurement.
    repeat : int, optional
        The number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first "1" is warm up and will be discarded.
        The returned result contains `repeat` costs,
        each of which is an average of `number` costs.
    min_repeat_ms: int, optional
        The minimum duration of one `repeat` in milliseconds.
        By default, one `repeat` contains `number` runs. If this parameter is set,
        the parameters `number` will be dynamically adjusted to meet the
        minimum duration requirement of one `repeat`.
        i.e., When the run time of one `repeat` falls below this time, the `number` parameter
        will be automatically increased.
    cooldown_interval: float, optional
        The cool down interval between two measurements.
    check_correctness: bool, optional
        Whether check correctness after measurement. This will use llvm cpu target to
        call your template and get the reference output.
        This can work for TOPI templates, but may not work for your custom template.
    """
    def __init__(self,
                 key, host, port, priority=1,
                 timeout=10, n_parallel=None,
                 number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
                 check_correctness=False):
        super(RPCRunner, self).__init__(timeout, n_parallel)

        # where and how to ask for a device
        self.key = key
        self.host = host
        self.port = port
        self.priority = priority
        self.timeout = timeout

        # timing parameters
        self.number = number
        self.repeat = repeat
        self.min_repeat_ms = min_repeat_ms
        self.cooldown_interval = cooldown_interval

        # reference data, filled in by set_task when check_correctness is on
        self.check_correctness = check_correctness
        self.ref_input = None
        self.ref_output = None

        self.executor = LocalExecutor()

    def set_task(self, task):
        """Bind the task, verify a device is reachable and prepare reference data."""
        self.task = task

        if not check_remote(task.target, self.key, self.host, self.port):
            raise RuntimeError("Cannot get remote devices from the tracker. "
                               "Please check the status of tracker by "
                               "'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
                               "and make sure you have free devices on the queue status.")
        logger.info("Get devices for measurement successfully!")

        if not self.check_correctness:
            return

        # use llvm cpu to generate a reference input/output
        # this option works for tuning topi, but might not work for you custom op
        with _target.create("llvm"):
            s, arg_bufs = task.instantiate(task.config_space.get(0))
        self.ref_input = [np.random.uniform(size=get_const_tuple(buf.shape)).astype(buf.dtype)
                          for buf in arg_bufs]
        func = build(s, arg_bufs, "llvm")
        tvm_buf = [nd.array(arr) for arr in self.ref_input]
        func(*tvm_buf)
        self.ref_output = [buf.asnumpy() for buf in tvm_buf]

    def get_build_kwargs(self):
        """Collect extra build arguments derived from the remote device."""
        target_keys = self.task.target.keys
        if not any(name in target_keys for name in ('cuda', 'opencl', 'rocm')):
            return {}

        # query the limits of the actual device so that invalid kernels can be
        # rejected at build time
        remote = request_remote(self.key, self.host, self.port)
        ctx = remote.context(str(self.task.target), 0)
        max_dims = ctx.max_thread_dimensions
        kwargs = {
            'check_gpu': {
                'max_shared_memory_per_block': ctx.max_shared_memory_per_block,
                'max_threads_per_block': ctx.max_threads_per_block,
                'max_thread_x': max_dims[0],
                'max_thread_y': max_dims[1],
                'max_thread_z': max_dims[2],
            },
        }

        if 'cuda' in target_keys:
            kwargs["cuda_arch"] = "sm_" + "".join(ctx.compute_version.split('.'))

        return kwargs

    def run(self, measure_inputs, build_results):
        """Run the built libraries, `n_parallel` at a time, and collect the results."""
        results = []
        remote_args = (self.key, self.host, self.port, self.priority, self.timeout)

        for begin in range(0, len(measure_inputs), self.n_parallel):
            end = begin + self.n_parallel
            futures = [self.executor.submit(run_through_rpc,
                                            measure_inp,
                                            build_res,
                                            self.number,
                                            self.repeat,
                                            self.min_repeat_ms,
                                            self.cooldown_interval,
                                            remote_args,
                                            self.ref_input,
                                            self.ref_output)
                       for measure_inp, build_res in zip(measure_inputs[begin:end],
                                                         build_results[begin:end])]

            for future in futures:
                res = future.get()
                if isinstance(res, Exception):   # executor error or timeout
                    res = MeasureResult((str(res),), MeasureErrorNo.RUN_TIMEOUT,
                                        self.timeout, time.time())
                results.append(res)

        return results

class LocalRunner(RPCRunner):
    """Run generated code on local devices.

    Parameters
    ----------
    timeout: float
        The timeout forwarded to RPCRunner
    number: int
        The number of times to run the generated code for taking average.
        We call these runs as one `repeat` of measurement.
    repeat : int, optional
        The number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first one is warm up and will be discarded.
        The returned result contains `repeat` costs,
        each of which is an average of `number` costs.
    min_repeat_ms: int, optional
        The minimum duration of one `repeat` in milliseconds.
        By default, one `repeat` contains `number` runs. If this parameter is set,
        the parameters `number` will be dynamically adjusted to meet the
        minimum duration requirement of one `repeat`.
        i.e., When the run time of one `repeat` falls below this time, the `number` parameter
        will be automatically increased.
    cooldown_interval: float, optional
        The cool down interval between two measurements.
    check_correctness: bool, optional
        Whether check correctness after measurement. This will use llvm cpu target to
        call your template and get the reference output.
        This can work for TOPI templates, but may not work for your custom template.

    Note
    ----
    This is a "fake" local mode. We start a silent rpc tracker and rpc server
    for the user. In this way we reuse timeout/isolation mechanism in RPC infrastructure.
    """
    def __init__(self,
                 timeout=10,
                 number=4, repeat=3, min_repeat_ms=0, cooldown_interval=0.1,
                 check_correctness=False):
        # key/host/port are placeholders; set_task fills them in once the
        # private tracker has been started
        super(LocalRunner, self).__init__(key='', host=None, port=None, priority=0,
                                          timeout=timeout, n_parallel=1,
                                          number=number, repeat=repeat,
                                          min_repeat_ms=min_repeat_ms,
                                          cooldown_interval=cooldown_interval,
                                          check_correctness=check_correctness)
        self.tracker = None
        self.server = None

    def set_task(self, task):
        """Start a private tracker and server, then bind the task.

        Returns
        -------
        server, tracker
            The started Server and Tracker objects.
        """
        # pylint: disable=import-outside-toplevel
        from ...rpc.tracker import Tracker
        from ...rpc.server import Server

        self.task = task

        tracker = Tracker('0.0.0.0', port=9000, port_end=10000, silent=True)
        # derive the device key from the tracker port
        device_key = '$local$device$%d' % tracker.port
        tracker_addr = (tracker.host, tracker.port)
        server = Server('0.0.0.0', port=9000, port_end=10000,
                        key=device_key,
                        use_popen=True, silent=True,
                        tracker_addr=tracker_addr)

        self.key, self.host, self.port = device_key, tracker.host, tracker.port

        super(LocalRunner, self).set_task(task)
        return server, tracker
344 345


346 347 348 349 350 351 352 353 354 355 356 357 358 359 360
def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None):
    """Common part for building a configuration

    Parameters
    ----------
    measure_input: MeasureInput
        Unpacked as (target, task, config)
    check_gpu: dict, optional
        Keyword arguments for `gpu_verify_pass`. When given, a verification
        pass is added to the lowering to filter out invalid configs in advance.
    cuda_arch: str, optional
        When given, forwarded to `set_cuda_target_arch` before building.
    build_option: dict, optional
        Extra keyword arguments for `build_config`. The dict is copied and
        never modified.

    Returns
    -------
    func
        The built module
    arg_info: tuple
        One (shape, dtype) pair per argument of the instantiated template

    Raises
    ------
    InstantiationError
        If the config is not valid after instantiating the template.
    """
    target, task, config = measure_input
    with target:
        s, args = task.instantiate(config)

        # check invalidity of template and code hash consistency
        if not config.valid():
            raise InstantiationError(config.errors)

        # work on a copy: the verify pass added below must not leak into the
        # caller's dict
        opts = dict(build_option) if build_option else {}
        if check_gpu:  # Add verify pass to filter out invalid configs in advance.
            opts["add_lower_pass"] = [(2, gpu_verify_pass(**check_gpu))]
        if cuda_arch:
            set_cuda_target_arch(cuda_arch)

        # if target is vta, we need to use vta build
        if hasattr(measure_input.target, 'device_name') and \
            measure_input.target.device_name == 'vta':
            # pylint: disable=import-outside-toplevel
            import vta
            func = vta.build(s, args, target_host=task.target_host)
        else:
            with build_config(**opts):
                func = build(s, args, target_host=task.target_host)
    return func, tuple((get_const_tuple(x.shape), x.dtype) for x in args)
372 373


374
def _wrap_build_func(build_func):
375
    """
376
    Wrap build_func to a function that can be used in measure.
377 378 379

    Parameters
    ----------
380 381
    build_func : The compilation function
        We expect fcompile to contain an attr "output_format"
382

383 384 385 386
    Returns
    -------
    wrapped_build_func : function
        The wrapped build function
387
    """
388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414
    if not hasattr(build_func, "output_format"):
        raise AttributeError("Expect build_func to have the attribute output_format.")
    output_format = build_func.output_format

    def _wrapped(measure_input, tmp_dir, **kwargs):
        """
        Wrapped build func.

        Parameters
        ----------
        measure_input: MeasureInput
            The input of measurement

        tmp_dir: str
            The path of temporary directory to export generated library
        """
        tic = time.time()
        try:
            filename = os.path.join(tmp_dir, "tmp_func_%0x.%s" % (
                getrandbits(64), output_format))
            # TODO(tvm-team) consider linline _build_func_common
            func, arg_info = _build_func_common(measure_input, **kwargs)
            func.export_library(filename, build_func)
        except Exception as e:  # pylint: disable=broad-except
            return BuildResult(None, None, e, time.time() - tic)
        return BuildResult(filename, arg_info, None, time.time() - tic)
    return _wrapped
415 416 417


def run_through_rpc(measure_input, build_result,
                    number, repeat, min_repeat_ms, cooldown_interval,
                    remote_args, ref_input=None, ref_output=None):
    """Upload a built library to a remote device, time it and report the result.

    Parameters
    ----------
    measure_input: MeasureInput
        The raw measure input
    build_result: BuildResult
        The result returned from Builder. This contains the path to the generated library.
        A MeasureResult passed here is returned unchanged.
    number: int
        The number of times to run the generated code for taking average.
        We call these runs as one `repeat` of measurement.
    repeat : int
        The number of times to repeat the measurement.
        In total, the generated code will be run (1 + number x repeat) times,
        where the first one is warm up and will be discarded.
        The returned result contains `repeat` costs,
        each of which is an average of `number` costs.
    min_repeat_ms: int
        The minimum duration of one `repeat` in milliseconds.
        By default, one `repeat` contains `number` runs. If this parameter is set,
        the parameters `number` will be dynamically adjusted to meet the
        minimum duration requirement of one `repeat`.
        i.e., When the run time of one `repeat` falls below this time, the `number` parameter
        will be automatically increased.
    cooldown_interval: float
        The cool down interval between two measurements
    remote_args: Tuple
        The argument for request_remote
    ref_input: List of np.ndarray
        The reference input used for checking correctness
    ref_output: List of np.ndarray
        The reference output used for checking correctness

    Returns
    -------
    MeasureResult
        The costs (or a RuntimeError on device failure), the error number,
        the elapsed time including the build time, and a timestamp.
    """
    # a failed build already comes as a MeasureResult; pass it through
    if isinstance(build_result, MeasureResult):
        return build_result

    started = time.time()
    errno = MeasureErrorNo.NO_ERROR
    try:
        # upload built module
        remote = request_remote(*remote_args)

        # Program the FPGA every single time when targeting VTA
        is_vta = hasattr(measure_input.target, 'device_name') and \
            measure_input.target.device_name == 'vta'
        if is_vta:
            # pylint: disable=import-outside-toplevel
            from vta import program_fpga, reconfig_runtime
            program_fpga(remote, None)
            reconfig_runtime(remote)

        remote.upload(build_result.filename)
        # the module is loaded by its base name
        lib_name = os.path.split(build_result.filename)[1]
        func = remote.load_module(lib_name)
        ctx = remote.context(str(measure_input.target), 0)
        time_f = func.time_evaluator(
            func.entry_name, ctx, number=number, repeat=repeat, min_repeat_ms=min_repeat_ms)

        # set input
        if ref_input:
            args = [nd.array(arr, ctx=ctx) for arr in ref_input]
        else:
            # create empty arrays on the remote device and copy them once.
            # This can avoid some memory issues that make the measurement results unreliable.
            args = [nd.empty(info[0], dtype=info[1], ctx=ctx)
                    for info in build_result.arg_info]
            args = [nd.array(arr, ctx=ctx) for arr in args]
            ctx.sync()

        costs = time_f(*args).results

        # clean up remote files
        remote.remove(build_result.filename)
        remote.remove(os.path.splitext(build_result.filename)[0] + '.so')
        remote.remove('')

        if len(costs) > 2:  # remove largest and smallest value to reduce variance
            costs = tuple(sorted(costs)[1:-1])

        # check correctness of output
        if ref_output:
            for expected, real in zip(ref_output, args):
                if np.allclose(expected, real.asnumpy(), rtol=1e-4):
                    continue
                logger.warning("Wrong Answer!")
                errno = MeasureErrorNo.WRONG_ANSWER
    except TVMError as exc:
        # keep only the part of the message that precedes the bulky sections
        msg = str(exc)
        for marker in ("Stack trace returned", "CUDA Source"):
            if marker in msg:
                msg = msg[:msg.index(marker)]
        costs = (RuntimeError(msg[:1024]),)
        errno = MeasureErrorNo.RUNTIME_DEVICE

    # take the timestamp before cooling down so the pause is not counted
    tstamp = time.time()
    time.sleep(cooldown_interval)
    elapsed = tstamp - started + build_result.time_cost
    return MeasureResult(costs, errno, elapsed, tstamp)


def request_remote(device_key, host=None, port=None, priority=1, timeout=60):
    """Ask the RPC tracker for a session on a registered device

    Parameters
    ----------
    device_key: string
        The device key of registered device in tracker
    host: str, optional
        The host address of rpc tracker.
        If is none, will use environment variable "TVM_TRACKER_HOST"
    port: int, optional
        The port of rpc tracker.
        If is none, will use environment variable "TVM_TRACKER_PORT"
    priority: int, optional
        The priority of this request, larger is more prior
    timeout: float, optional
        The timeout of this session (units: second)

    Returns
    -------
    session: RPCSession

    Raises
    ------
    KeyError
        If host or port is not given and the matching environment variable
        is not set.
    """
    # fall back to the environment for whatever the caller left out
    if not host:
        host = os.environ['TVM_TRACKER_HOST']
    if not port:
        port = int(os.environ['TVM_TRACKER_PORT'])

    # connect to the tracker and request the device
    return _rpc.connect_tracker(host, port).request(device_key,
                                                    priority=priority,
                                                    session_timeout=timeout)
545 546


547
def check_remote(target, device_key, host=None, port=None, priority=100, timeout=10):
    """
    Check the availability of a remote device

    Parameters
    ----------
    target: Target
        The wanted compilation target
    device_key: string
        device key of registered device in tracker
    host: host, optional
        The host address of rpc tracker.
        If is none, will use environment variable "TVM_TRACKER_HOST"
    port: int, optional
        The port address of rpc tracker.
        If is none, will use environment variable "TVM_TRACKER_PORT"
    priority: int, optional
        The priority of this request, larger is more prior
    timeout: float, optional
        The timeout of this check (units: seconds).

    Returns
    -------
    available: bool
        True if can find available device. False if no device showed up
        within `timeout` or if the probe itself failed.
    """
    found = threading.Event()      # set by the probe once the device exists
    abandoned = threading.Event()  # set by the caller when it stops waiting

    def _check():
        try:
            remote = request_remote(device_key, host, port, priority)
            ctx = remote.context(str(target))
            while not ctx.exist:  # wait until we get an available device
                # pause between polls instead of spinning, and stop once the
                # caller has given up
                if abandoned.wait(0.05):
                    return
            found.set()
        except Exception as exc:  # pylint: disable=broad-except
            # a probe that crashed must not be mistaken for a probe that
            # finished successfully
            logging.getLogger('autotvm').warning(
                "Failed to check remote device: %s", exc)

    t = threading.Thread(target=_check)
    # do not keep the interpreter alive if the request never returns
    t.daemon = True
    t.start()
    t.join(timeout)
    abandoned.set()
    return found.is_set()
582 583


584 585 586
@register_func
def tvm_callback_cuda_compile(code):
    """Compile CUDA source with nvcc, using the arch stored in the autotvm global scope"""
    arch = AutotvmGlobalScope.current.cuda_target_arch
    # e.g., target arch could be [
    #   "-gencode", "arch=compute_52,code=sm_52",
    #   "-gencode", "arch=compute_70,code=sm_70"
    # ]
    # a list of arguments selects "fatbin" output, anything else "ptx"
    if isinstance(arch, list):
        target = "fatbin"
    else:
        target = "ptx"
    return nvcc.compile_cuda(code, target=target, arch=arch)
595 596


597 598 599 600 601
def set_cuda_target_arch(arch):
    """Store the nvcc target architecture in the current autotvm global scope

    Parameters
    ----------
    arch: str or list
        The argument of nvcc -arch. (e.g. "sm_51", "sm_62")
        it can also be a list of gencode arguments passed to the nvcc command line,
        e.g., ["-gencode", "arch=compute_52,code=sm_52", "-gencode", "arch=compute_70,code=sm_70"]
    """
    scope = AutotvmGlobalScope.current
    scope.cuda_target_arch = arch


610
def gpu_verify_pass(**kwargs):
    """Create a pass that verifies the validity of a gpu kernel.

    The keyword arguments are handed to `ir_pass.VerifyGPUCode` as the
    constraints to check.

    Returns
    -------
    verify_pass : function
        Takes a statement and returns it unchanged when verification
        succeeds; raises InstantiationError otherwise.
    """
    def verify_pass(stmt):
        if ir_pass.VerifyGPUCode(stmt, kwargs):
            return stmt
        raise InstantiationError("Skipped because of invalid gpu kernel")

    return verify_pass