# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=superfluous-parens, redefined-outer-name, pointless-string-statement
# pylint: disable=consider-using-enumerate,invalid-name
"""Tuning record and serialization format"""

import argparse
import base64
import logging
import multiprocessing
import pickle
import json
import time
import os
import itertools
from collections import OrderedDict
import numpy as np

from .. import build, lower, target as _target
from .. import __version__
from . import task
from .task import ConfigEntity, ApplyHistoryBest
from .measure import MeasureInput, MeasureResult

AUTOTVM_LOG_VERSION = 0.2
_old_version_warning = True
logger = logging.getLogger('autotvm')

try:  # convert unicode to str for python2
    _unicode = unicode
except NameError:
    _unicode = ()

try:
    _long = long
except NameError:
    _long = int


def measure_str_key(inp, include_config=True):
    """ get unique str key for MeasureInput

    Parameters
    ----------
    inp: MeasureInput
        input for the measure
    include_config: bool, optional
        whether includes config in the str key

    Returns
    -------
    key: str
        The str representation of key
    """
    config_str = str(inp.config) if include_config else ""
    return "".join([str(inp.target), inp.task.name, str(inp.task.args),
                    str(inp.task.kwargs), config_str])

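# An illustrative sketch (assuming `inp` is a MeasureInput, e.g. obtained from
# load_from_file below): keys built with include_config=False identify a
# workload, while the default key is unique per (workload, config) pair:
#
#   wkl_key = measure_str_key(inp, include_config=False)  # one key per workload
#   rec_key = measure_str_key(inp)                        # one key per record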

def encode(inp, result, protocol='json'):
    """encode (MeasureInput, MeasureResult) pair to a string

    Parameters
    ----------
    inp: autotvm.tuner.MeasureInput
    result: autotvm.tuner.MeasureResult
        pair of input/result
    protocol: str
        log protocol, json or pickle

    Returns
    -------
    row: str
        a row in the logger file
    """

    if protocol == 'json':
        json_dict = {
            "input": (str(inp.target),
                      inp.task.name, inp.task.args, inp.task.kwargs),

            "config": inp.config.to_json_dict(),

            "result": (result.costs if result.error_no == 0 else (1e9,),
                       result.error_no,
                       result.all_cost,
                       result.timestamp),

            "version": AUTOTVM_LOG_VERSION,

            "tvm_version": __version__
        }
        return json.dumps(json_dict)
    if protocol == 'pickle':
        row = (str(inp.target),
               str(base64.b64encode(pickle.dumps([inp.task.name,
                                                  inp.task.args,
                                                  inp.task.kwargs])).decode()),
               str(base64.b64encode(pickle.dumps(inp.config)).decode()),
               str(base64.b64encode(pickle.dumps(tuple(result))).decode()),
               str(AUTOTVM_LOG_VERSION),
               str(__version__))
        return '\t'.join(row)

    raise RuntimeError("Invalid log protocol: " + protocol)
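
# For illustration, a sketch of what a 'json' row looks like (values invented,
# not from a real run); each record is one JSON object per line:
#
#   {"input": ["llvm", "conv2d", [...], {...}],
#    "config": {...},
#    "result": [[0.0012], 0, 1.5, 1532000000.0],
#    "version": 0.2,
#    "tvm_version": "0.6.0"}
#
# The 'pickle' protocol instead emits one tab-separated line of base64-encoded
# pickles: faster to parse, but not human-readable.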


def decode(row, protocol='json'):
    """Decode encoded record string to python object

    Parameters
    ----------
127
    row : str
128
        a row in the logger file
129 130

    protocol : str
131 132 133 134
        log protocol, json or pickle

    Returns
    -------
135 136
    ret : tuple(autotvm.tuner.MeasureInput, autotvm.tuner.MeasureResult), or None
        The tuple of input and result, or None if input uses old version log format.
137 138
    """
    # pylint: disable=unused-variable
    global _old_version_warning

    if protocol == 'json':
        row = json.loads(row)
        if 'v' in row and row['v'] == 0.1:
            if _old_version_warning:
                logger.warning("AutoTVM log version 0.1 is no longer supported.")
                _old_version_warning = False
            return None

        tgt, task_name, task_args, task_kwargs = row["input"]
        tgt = _target.create(str(tgt))

        def clean_json_to_python(x):
            """1. Convert all list in x to tuple (hashable)
               2. Convert unicode to str for python2
155 156 157 158 159
            """
            if isinstance(x, list):
                return tuple([clean_json_to_python(a) for a in x])
            if isinstance(x, _unicode):
                return str(x)
            if isinstance(x, (_long, int)):
                return int(x)
            return x

        tsk = task.Task(clean_json_to_python(task_name), clean_json_to_python(task_args))
        config = ConfigEntity.from_json_dict(row["config"])
        inp = MeasureInput(tgt, tsk, config)
        result = MeasureResult(*[tuple(x) if isinstance(x, list) else x for x in row["result"]])
        config.cost = np.mean(result.costs)

        return inp, result
    if protocol == 'pickle':
        items = row.split("\t")
        if len(items) == 4:
            if _old_version_warning:
                logger.warning("AutoTVM log version 0.1 is no longer supported.")
                _old_version_warning = False
            return None
        tgt = _target.create(items[0])
        task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
        config = pickle.loads(base64.b64decode(items[2].encode()))
        result = MeasureResult(*pickle.loads(base64.b64decode(items[3].encode())))
        config.cost = np.mean(result.costs)

        tsk = task.Task(task_tuple[0], task_tuple[1])
        return MeasureInput(tgt, tsk, config), result

    raise RuntimeError("Invalid log protocol: " + protocol)

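# decode is the inverse of encode, so a round trip preserves a record.
# A minimal sketch, assuming `inp` and `res` come from a real tuning run:
#
#   row = encode(inp, res, protocol='json')
#   inp2, res2 = decode(row, protocol='json')
#   assert measure_str_key(inp2) == measure_str_key(inp)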

def load_from_file(filename):
    """Generator: load records from file.
    This is a generator that yields the records.

    Parameters
    ----------
    filename: str

    Yields
    ------
    input: autotvm.tuner.MeasureInput
    result: autotvm.tuner.MeasureResult
    """
    for row in open(filename):
        if row and not row.startswith('#'):
            ret = decode(row)
            if ret is None:
                continue
            inp, res = ret
            # Avoid loading the record with an empty config. The TOPI schedule with no entities
            # will result in an empty entity map (e.g., depthwise_conv2d_nchw on x86).
            # Using an empty config will cause problems when applying alter op like NCHW to NCHWc.
            if not inp.config._entity_map:
                continue
            yield (inp, res)
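
# A usage sketch ('collect.log' is a placeholder filename): iterate the records
# lazily and print the mean measured cost of each successful one:
#
#   for inp, res in load_from_file("collect.log"):
#       if res.error_no == 0:
#           print(inp.task.name, np.mean(res.costs))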


def split_workload(in_file, clean=True):
    """Split a log file into separate files, each of which contains only a single workload
    This function can also delete duplicated records in log file

    Parameters
    ----------
    in_file: str
        input filename
    clean: bool
        whether delete duplicated items
    """
    tic = time.time()
    lines = open(in_file).readlines()

    logger.info("start converting...")
    pool = multiprocessing.Pool()
    lines = [rec for rec in pool.map(decode, lines) if rec is not None]
    logger.info("map done %.2f", time.time() - tic)

    wkl_dict = OrderedDict()
    for inp, res in lines:
        wkl = measure_str_key(inp, False)
        if wkl not in wkl_dict:
            wkl_dict[wkl] = []
        wkl_dict[wkl].append([inp, res])

    if clean:
        for i, (k, v) in enumerate(wkl_dict.items()):
            # clean duplicated items
            added = set()
            cleaned = []
            for inp, res in v:
                str_key = measure_str_key(inp)
                if str_key in added:
                    continue
                added.add(str_key)
                cleaned.append([inp, res])

            # write to file
            logger.info("Key: %s\tValid: %d\tDup: %d\t", k, len(cleaned), len(v) - len(cleaned))
            with open(in_file + ".%03d.wkl" % i, 'w') as fout:
                for inp, res in cleaned:
                    fout.write(encode(inp, res) + '\n')
    else:
        for i, (k, v) in enumerate(wkl_dict.items()):
            logger.info("Key: %s\tNum: %d", k, len(v))
            with open(in_file + ".%03d.wkl" % i, 'w') as fout:
                for inp, res in v:
                    fout.write(encode(inp, res) + '\n')

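# A usage sketch ('collect.log' is a placeholder): this writes one
# 'collect.log.NNN.wkl' file per workload found in the input:
#
#   split_workload("collect.log")               # split and de-duplicate
#   split_workload("collect.log", clean=False)  # split only
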
def pick_best(in_file, out_file):
    """
    Pick best entries from a file and store it to another file.
    This distill the useful log entries from a large log file.
271 272
    If out_file already exists, the best entries from both
    in_file and out_file will be saved.
273 274 275 276 277

    Parameters
    ----------
    in_file: str
        The filename of input
278
    out_file: str or file
279 280
        The filename of output
    """
    context = load_from_file(in_file)
    if os.path.isfile(out_file):
        out_context = load_from_file(out_file)
        context = itertools.chain(context, out_context)
    context, context_clone = itertools.tee(context)
    best_context = ApplyHistoryBest(context)
    best_set = set()

    for v in best_context.best_by_model.values():
        best_set.add(measure_str_key(v[0]))

    for v in best_context.best_by_targetkey.values():
        best_set.add(measure_str_key(v[0]))

    logger.info("Extract %d best records from the %s", len(best_set), in_file)
    fout = open(out_file, 'w') if isinstance(out_file, str) else out_file

    for inp, res in context_clone:
        if measure_str_key(inp) in best_set:
            fout.write(encode(inp, res) + "\n")
            best_set.remove(measure_str_key(inp))

    # close the output file only if we opened it in this function
    if isinstance(out_file, str):
        fout.close()

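# A usage sketch (filenames are placeholders); the distilled log can then be
# consumed by ApplyHistoryBest (imported above) when compiling:
#
#   pick_best("collect.log", "collect.best.log")
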
"""
Usage:
This record executable module has three modes.

* Print log file in readable format
e.g. python -m tvm.autotvm.record --mode read --i collect_conv.log --begin 0 --end 5 --ir --code

* Extract history best from a large log file
e.g. python -m tvm.autotvm.record --mode pick --i collect.log

* Split a log file into separate files, each of which contains only a single workload
e.g. python -m tvm.autotvm.record --mode split --i collect.log
"""
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=['read', 'pick', 'split'], default='read')
    parser.add_argument("--i", type=str, help="input file")
    parser.add_argument("--o", type=str, default=None, help='output file')
    parser.add_argument("--begin", type=int, default=0)
    parser.add_argument("--end", type=int, default=5)
    parser.add_argument("--ir", action='store_true')
    parser.add_argument("--code", action='store_true')

    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    if args.mode == 'pick':
        args.o = args.o or args.i + ".best.log"
        pick_best(args.i, args.o)
    elif args.mode == 'read':
        for i, (inp, result) in enumerate(load_from_file(args.i)):
            if args.begin <= i < args.end:
                with inp.target:
                    s, arg_bufs = inp.task.instantiate(inp.config)

                print("")
                print(inp.target, inp.task, inp.config)
                print(result)

                if args.ir:
                    with inp.target:
                        print(lower(s, arg_bufs, simple_mode=True))

                if args.code:
                    with inp.target:
                        func = build(s, arg_bufs)
                        print(func.imported_modules[0].get_source())
    elif args.mode == 'split':
        split_workload(args.i)