# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=pointless-string-statement,consider-using-enumerate,invalid-name
"""User facing API for specifying how to measure the generated code"""
import multiprocessing
from collections import namedtuple

class MeasureInput(namedtuple("MeasureInput", ["target", "task", "config"])):
    """
    Stores all the necessary inputs for a measurement.

    Parameters
    ----------
    target : Target
        The target device
    task : task.Task
        Task function
    config : ConfigEntity
        Specific configuration.
    """


class MeasureResult(namedtuple("MeasureResult", ["costs", "error_no", "all_cost", "timestamp"])):
    """
    Stores all the results of a measurement

    Parameters
    ----------
    costs: Array of float or Array of Exception
        If no error occurs during measurement, it is an array of measured running times.
        If an error occurs during measurement, it is an array of the exception objects.
    error_no: int
        Denote error type, defined by MeasureErrorNo
    all_cost: float
        All cost of this measure, including rpc, compilation, test runs
    timestamp: float
        The absolute time stamp when we finish measurement.
    """


class MeasureErrorNo(object):
    """Error type for MeasureResult (the value stored in ``MeasureResult.error_no``)"""
    NO_ERROR = 0              # no error
    INSTANTIATION_ERROR = 1   # actively detected error in instantiating a template with a config
    COMPILE_HOST = 2          # error when compiling code on host
    COMPILE_DEVICE = 3        # error when compiling code on device (e.g. OpenCL JIT on the device)
    RUNTIME_DEVICE = 4        # error when running program on device
    WRONG_ANSWER = 5          # answer is wrong when compared to a golden output
    BUILD_TIMEOUT = 6         # timeout during compilation
    RUN_TIMEOUT = 7           # timeout during run
    UNKNOWN_ERROR = 8         # unknown error


class Builder(object):
    """Builder that builds programs in tuning

    Parameters
    ----------
    timeout: float, optional
        The timeout of a build task
    n_parallel: int, optional
        The number of tasks submitted in parallel
        By default it will use all cpu cores
    """
    def __init__(self, timeout=10, n_parallel=None):
        self.timeout = timeout
        # ``or`` also maps a falsy n_parallel (None or 0) to the cpu count.
        self.n_parallel = n_parallel or multiprocessing.cpu_count()
        self.build_kwargs = {}
        self.task = None

    def set_task(self, task, build_kwargs=None):
        """
        Initialize for a new tuning task

        Parameters
        ----------
        task: Task
            The tuning task
        build_kwargs: dict, optional
            The additional kwargs for build function.
            ``None`` (the default) is stored as an empty dict.
        """
        self.task = task
        # Keep build_kwargs a dict (as set in __init__) even when none are given,
        # so subclasses can always treat it as a mapping.
        self.build_kwargs = build_kwargs if build_kwargs is not None else {}

    def build(self, measure_inputs):
        """Build programs

        Parameters
        ----------
        measure_inputs: List of MeasureInput
            The measure input

        Returns
        -------
        build_results: List of BuildResult
            The build result.
        """
        raise NotImplementedError()

class Runner(object):
    """Runner that runs and measures the time cost of a generated program in tuning

    Parameters
    ----------
    timeout: float, optional
        The timeout of a run task
    n_parallel: int, optional
        The number of tasks submitted in parallel
        By default it will use all cpu cores
    """
    def __init__(self, timeout=5, n_parallel=None):
        self.timeout = timeout
        # ``or`` also maps a falsy n_parallel (None or 0) to the cpu count.
        self.n_parallel = n_parallel or multiprocessing.cpu_count()
        self.task = None

    def set_task(self, task):
        """
        Initialize for a new tuning task

        Parameters
        ----------
        task: Task
            The tuning task

        Returns
        -------
        None in this base class. create_measure_batch stores whatever this
        method returns as ``measure_batch.attach_objects``.
        """
        self.task = task

    def get_build_kwargs(self):
        """
        Get device specific build arguments (e.g. maximum shared memory size)

        Returns
        ----------
        kwargs: dict
            The additional keyword arguments
        """
        raise NotImplementedError()

    def run(self, measure_inputs, build_results):
        """Run and measure built programs

        Parameters
        ----------
        measure_inputs: List of MeasureInput
            The raw measure input
        build_results: List of BuildResults
            The build results

        Returns
        -------
        measure_results: List of MeasureResult
            The final results of measurement
        """
        raise NotImplementedError()

def measure_option(builder, runner):
    """
    Set options for measure. To measure a config, we will build it and run it.
    So we have to set options for these two steps.
    They have their own options on timeout, parallel, etc.

    Parameters
    ----------
    builder: Builder or str
        Specify how to build programs. The string ``'local'`` is shorthand
        for ``LocalBuilder()``.
    runner: Runner or str
        Specify how to run programs. The string ``'local'`` is shorthand
        for ``LocalRunner()``.

    Returns
    -------
    opt: dict
        ``{'builder': builder, 'runner': runner}``; pass it as the ``option``
        argument of :any:`create_measure_batch`.

    Raises
    ------
    ValueError
        If ``builder`` or ``runner`` is a string other than ``'local'``.

    Examples
    --------
    # example setting for using local devices
    >>> measure_option = autotvm.measure_option(
    >>>     builder=autotvm.LocalBuilder(),      # use all local cpu cores for compilation
    >>>     runner=autotvm.LocalRunner(          # measure them sequentially
    >>>         number=10,
    >>>         timeout=5)
    >>> )

    # example setting for using remote devices
    >>> measure_option = autotvm.measure_option(
    >>>    builder=autotvm.LocalBuilder(),  # use all local cpu cores for compilation
    >>>    runner=autotvm.RPCRunner(
    >>>        'rasp3b', 'localhost', 9190, # device key, host and port of the rpc tracker
    >>>        number=4,
    >>>        timeout=4) # timeout of a run on the device. RPC request waiting time is excluded.
    >>> )

    Note
    ----
    To make measurement results accurate, you should pick the correct value for the argument
    `number` and `repeat` in Runner(). Some devices need a certain minimum running time to
    "warm up," such as GPUs that need time to reach a performance power state.
    Using `min_repeat_ms` dynamically adjusts `number`, so it is recommended.
    The typical value for NVIDIA GPU is 150 ms.
    """
    if isinstance(builder, str):
        # Validate before importing so a bad name fails with ValueError, not a
        # TypeError from concatenating a str with a builder object.
        if builder != 'local':
            raise ValueError("Invalid builder: " + builder)
        # Imported lazily: only the string shorthand needs the concrete class.
        from .measure_methods import LocalBuilder
        builder = LocalBuilder()

    if isinstance(runner, str):
        if runner != 'local':
            raise ValueError("Invalid runner: " + runner)
        from .measure_methods import LocalRunner
        runner = LocalRunner()

    opt = {
        'builder': builder,
        'runner': runner,
    }

    return opt

def create_measure_batch(task, option):
    """Get a standard measure_batch function.

    Parameters
    ----------
    task: tvm.autotvm.task.Task
        The tuning task
    option: dict
        The option for measuring generated code.
        You should use the return value of function :any:`measure_option` for this argument.

    Returns
    -------
    measure_batch: callable
        a callback function to measure a batch of configs. Given a list of
        MeasureInput, it builds them with ``option['builder']``, runs the build
        results with ``option['runner']`` and returns the runner's results.
        It carries two attributes: ``n_parallel`` (the builder's ``n_parallel``)
        and ``attach_objects`` (whatever ``runner.set_task`` returned).
    """
    builder = option['builder']
    runner = option['runner']

    attach_objects = runner.set_task(task)

    # feed device related information from runner to builder
    # (e.g. max shared memory for validity checking)
    build_kwargs = runner.get_build_kwargs()
    builder.set_task(task, build_kwargs)

    def measure_batch(measure_inputs):
        build_results = builder.build(measure_inputs)
        return runner.run(measure_inputs, build_results)

    measure_batch.n_parallel = builder.n_parallel
    measure_batch.attach_objects = attach_objects
    return measure_batch