Unverified Commit f5c9bc93 by Andrew Reusch Committed by GitHub

Customize SI prefix in logging (#5411)

* Customize SI prefix in logging

* Include unit test
parent 8f9796bd
...@@ -23,6 +23,7 @@ import logging ...@@ -23,6 +23,7 @@ import logging
import numpy as np import numpy as np
from .. import record from .. import record
from ..util import format_si_prefix
logger = logging.getLogger('autotvm') logger = logging.getLogger('autotvm')
...@@ -105,7 +106,7 @@ class Monitor(object): ...@@ -105,7 +106,7 @@ class Monitor(object):
return np.array(self.timestamps) return np.array(self.timestamps)
def progress_bar(total, prefix=''): def progress_bar(total, prefix='', si_prefix='G'):
"""Display progress bar for tuning """Display progress bar for tuning
Parameters Parameters
...@@ -114,6 +115,8 @@ def progress_bar(total, prefix=''): ...@@ -114,6 +115,8 @@ def progress_bar(total, prefix=''):
The total number of trials The total number of trials
prefix: str prefix: str
The prefix of output message The prefix of output message
si_prefix: str
SI prefix for flops
""" """
class _Context(object): class _Context(object):
"""Context to store local variables""" """Context to store local variables"""
...@@ -130,6 +133,9 @@ def progress_bar(total, prefix=''): ...@@ -130,6 +133,9 @@ def progress_bar(total, prefix=''):
ctx = _Context() ctx = _Context()
tic = time.time() tic = time.time()
# Validate si_prefix argument
format_si_prefix(0, si_prefix)
if logger.level < logging.DEBUG: # only print progress bar in non-debug mode if logger.level < logging.DEBUG: # only print progress bar in non-debug mode
sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) ' sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) '
'| %.2f s' % (prefix, 0, 0, 0, total, time.time() - tic)) '| %.2f s' % (prefix, 0, 0, 0, total, time.time() - tic))
...@@ -147,10 +153,11 @@ def progress_bar(total, prefix=''): ...@@ -147,10 +153,11 @@ def progress_bar(total, prefix=''):
ctx.cur_flops = flops ctx.cur_flops = flops
ctx.best_flops = tuner.best_flops ctx.best_flops = tuner.best_flops
sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f GFLOPS | Progress: (%d/%d) ' sys.stdout.write('\r%s Current/Best: %7.2f/%7.2f %sFLOPS | Progress: (%d/%d) '
'| %.2f s' % '| %.2f s' %
(prefix, ctx.cur_flops/1e9, ctx.best_flops/1e9, ctx.ct, ctx.total, (prefix, format_si_prefix(ctx.cur_flops, si_prefix),
time.time() - tic)) format_si_prefix(ctx.best_flops, si_prefix), si_prefix,
ctx.ct, ctx.total, time.time() - tic))
sys.stdout.flush() sys.stdout.flush()
return _callback return _callback
...@@ -21,6 +21,7 @@ import logging ...@@ -21,6 +21,7 @@ import logging
import numpy as np import numpy as np
from ..measure import MeasureInput, create_measure_batch from ..measure import MeasureInput, create_measure_batch
from ..util import format_si_prefix
from ..env import GLOBAL_SCOPE from ..env import GLOBAL_SCOPE
...@@ -87,7 +88,7 @@ class Tuner(object): ...@@ -87,7 +88,7 @@ class Tuner(object):
""" """
def tune(self, n_trial, measure_option, early_stopping=None, callbacks=()): def tune(self, n_trial, measure_option, early_stopping=None, callbacks=(), si_prefix='G'):
"""Begin tuning """Begin tuning
Parameters Parameters
...@@ -104,6 +105,8 @@ class Tuner(object): ...@@ -104,6 +105,8 @@ class Tuner(object):
(Tuner, List of MeasureInput, List of MeasureResult) (Tuner, List of MeasureInput, List of MeasureResult)
with no return value. These callback functions will be called on with no return value. These callback functions will be called on
every measurement pair. See autotvm/tuner/callback.py for some examples. every measurement pair. See autotvm/tuner/callback.py for some examples.
si_prefix: str
One of tvm.autotvm.util.SI_PREFIXES. The SI prefix to use when reporting FLOPS.
""" """
measure_batch = create_measure_batch(self.task, measure_option) measure_batch = create_measure_batch(self.task, measure_option)
n_parallel = getattr(measure_batch, 'n_parallel', 1) n_parallel = getattr(measure_batch, 'n_parallel', 1)
...@@ -111,6 +114,9 @@ class Tuner(object): ...@@ -111,6 +114,9 @@ class Tuner(object):
self.n_trial = n_trial self.n_trial = n_trial
self.early_stopping = early_stopping self.early_stopping = early_stopping
# Validate si_prefix arg
format_si_prefix(0, si_prefix)
old_level = logger.level old_level = logger.level
GLOBAL_SCOPE.in_tuning = True GLOBAL_SCOPE.in_tuning = True
...@@ -140,9 +146,9 @@ class Tuner(object): ...@@ -140,9 +146,9 @@ class Tuner(object):
self.best_measure_pair = (inp, res) self.best_measure_pair = (inp, res)
self.best_iter = i + k self.best_iter = i + k
logger.debug("No: %d\tGFLOPS: %.2f/%.2f\tresult: %s\t%s", logger.debug("No: %d\t%sFLOPS: %.2f/%.2f\tresult: %s\t%s",
i + k + 1, flops / 1e9, self.best_flops / 1e9, i + k + 1, si_prefix, format_si_prefix(flops, si_prefix),
res, config) format_si_prefix(self.best_flops, si_prefix), res, config)
i += len(results) i += len(results)
self.ttl = min(early_stopping + self.best_iter, n_trial) - i self.ttl = min(early_stopping + self.best_iter, n_trial) - i
......
...@@ -188,3 +188,12 @@ def get_const_tuple(in_tuple): ...@@ -188,3 +188,12 @@ def get_const_tuple(in_tuple):
else: else:
ret.append(get_const_int(elem)) ret.append(get_const_int(elem))
return tuple(ret) return tuple(ret)
SI_PREFIXES = 'yzafpn\xb5m kMGTPEZY'
YOCTO_EXP10 = -24
def format_si_prefix(x, si_prefix):
exp10 = 10 ** (SI_PREFIXES.index(si_prefix) * 3 + YOCTO_EXP10)
return float(x) / exp10
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from numpy import isclose
import random
from tvm.autotvm import util
SI_PREFIXES = 'yzafpn\xb5m kMGTPEZY'
def test_format_si_prefix():
# test float conversion
assert util.format_si_prefix(1024, 'k') == 1.024
for i, prefix in enumerate(SI_PREFIXES):
integer, decimal = random.randint(0, 1000), random.randint(0, 1000)
exp = -24 + 3 * i # 0th prefix (yocto) is 10^-24
number = integer * (10 ** exp) + decimal * (10 ** (exp - 3))
expected = (integer + decimal / 1000)
assert isclose(util.format_si_prefix(number, prefix), expected)
assert util.format_si_prefix(0, 'y') == 0
if __name__ == '__main__':
test_format_si_prefix()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment