# pylint: disable=consider-using-enumerate,invalid-name
"""
Database of MeasureInput/MeasureResult pair.
This can be used for replaying measurement.
"""
import os

from .record import encode, decode, measure_str_key


class Database(object):
    """
    Base class for a record database object.

    Subclasses implement `load` and `save`; both raise NotImplementedError here.
    """

    def load(self, inp, get_all=False):
        """
        Load a result based on an input's string key

        Parameters
        ----------
        inp: MeasureInput
            to be translated into key for RedisDB
        get_all: bool, optional
            Whether the latest result (or all matching results) should be returned

        Returns
        -------
        rec: MeasureResult if previously saved, otherwise None
        """
        raise NotImplementedError()

    def save(self, inp, res, extend=False):
        """
        Save a result based on an input's string key

        Parameters
        ----------
        inp: MeasureInput
            to be translated into key for RedisDB
        res: MeasureResult
            to associate with key
        extend: bool, optional
            Whether to extend existing MeasureResults if they exist
        """
        raise NotImplementedError()


def filter_inputs(db, measure_inputs, retry=False):
    """
    Filter a measure_inputs batch based on saved db results

    Parameters
    ----------
    db: Database
        database object
    measure_inputs: Array of MeasureInput
        measure_inputs as expected in measure_batch
    retry: bool
        whether to retry if the saved result is a failure

    Returns
    -------
    partial_results: Array of MeasureResult
        a full list of result, where None denotes no corresponding saved result
    unsaved: Array of MeasureInput
        a list that only contains unsaved inputs
    """
    partial_results = []
    unsaved = []
    for inp in measure_inputs:
        res = db.load(inp)
        # A saved failure (non-zero error_no) counts as "unsaved" when retrying.
        if res is None or (retry and res.error_no != 0):
            unsaved.append(inp)
            partial_results.append(None)
        else:
            partial_results.append(res)
    return partial_results, unsaved


class RedisDatabase(Database):
    """
    Redis version of record database

    All records sharing one key are stored as a single value, with the
    encoded records joined by MAGIC_SPLIT.
    """

    REDIS_PROD = 15
    REDIS_LOCA = 14
    REDIS_TEST = 13  # for unit test
    REDIS_NIGHT_TEMP = 12  # for nightly report (will be flushed after every workload)

    MAGIC_SPLIT = "$"

    def __init__(self, db_index=REDIS_PROD):
        """
        Connect to the redis database selected by `db_index`.

        Parameters
        ----------
        db_index: int, optional
            redis database index. REDIS_TEST connects to localhost; any other
            index uses the host named by the TVM_FLEET_HOST environment variable.
        """
        # Imported lazily so that the module can be imported without redis installed.
        import redis

        if db_index == RedisDatabase.REDIS_TEST:
            host = 'localhost'
        else:
            host = os.environ.get('TVM_FLEET_HOST')
        self.db = redis.StrictRedis(host=host, port=6379, db=db_index)
        self.db_index = db_index

    def set(self, key, value):
        """Store `value` under `key` in the backing store."""
        self.db.set(key, value)

    def get(self, key):
        """
        Return the value stored under `key` as text, or None if it is missing.

        The backing client may hand values back as bytes; they are decoded here
        once so that callers can split them with the (text) MAGIC_SPLIT separator.
        """
        current = self.db.get(key)
        if isinstance(current, bytes):
            return current.decode('utf-8')
        return current

    def load(self, inp, get_all=False):
        """
        Load the saved result(s) for `inp`.

        Parameters
        ----------
        inp: MeasureInput
            translated into the lookup key with measure_str_key
        get_all: bool, optional
            return every saved result instead of only the latest one

        Returns
        -------
        None if nothing is saved; a list of results if `get_all` is set;
        otherwise the result with the largest timestamp.
        """
        current = self.get(measure_str_key(inp))
        if current is None:
            return None
        records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
        results = [rec[1] for rec in records]
        if get_all:
            return results
        return max(results, key=lambda result: result.timestamp)

    def save(self, inp, res, extend=False):
        """
        Save `res` under the key derived from `inp`.

        Parameters
        ----------
        inp: MeasureInput
            translated into the storage key with measure_str_key
        res: MeasureResult
            result to encode and store
        extend: bool, optional
            append to the records already stored under the key instead of
            overwriting them
        """
        key = measure_str_key(inp)
        entries = [encode(inp, res)]
        if extend:
            # Only an extending save needs to read the existing value.
            current = self.get(key)
            if current is not None:
                entries = current.split(RedisDatabase.MAGIC_SPLIT) + entries
        self.set(key, RedisDatabase.MAGIC_SPLIT.join(entries))

    def filter(self, func):
        """
        Dump all of the records for a particular target

        Parameters
        ----------
        func: callable
            The signature of the function is bool (MeasureInput, Array of MeasureResult)

        Returns
        -------
        list of records (inp, result) matching the target

        Examples
        --------
        get records for a target
        >>> db.filter(lambda inp, results: "cuda" in inp.target.keys)
        """
        matched_records = []
        # may consider filtering in iterator in the future
        for key in self.db.keys():
            current = self.get(key)
            if current is None:
                # the key disappeared between listing and reading it
                continue
            try:
                records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
            except TypeError:  # got a badly formatted/old format record
                continue
            # NOTE(review): assumes decode may yield None for records it cannot
            # parse -- confirm against .record; such entries are skipped here.
            records = [rec for rec in records if rec is not None]
            if not records:
                continue
            inps, results = zip(*records)
            # every record under one key shares the same input
            inp = inps[0]
            if not func(inp, results):
                continue
            result = max(results, key=lambda res: res.timestamp)
            matched_records.append((inp, result))
        return matched_records

    def flush(self):
        """Delete every key in the selected redis database."""
        self.db.flushdb()


class DummyDatabase(RedisDatabase):
    """
    A database based on python dictionary for testing.
    """

    def __init__(self):
        # pylint: disable=super-init-not-called
        self.db = {}

    def set(self, key, value):
        """Store `value` under `key` in the in-memory dictionary."""
        self.db[key] = value

    def get(self, key):
        """Return the value stored under `key`, or None if it is missing."""
        return self.db.get(key)

    def flush(self):
        """Drop all stored records by replacing the dictionary."""
        self.db = {}