Commit 062f8cc4 by Cody Hao Yu Committed by Haichen Shen

[AutoTVM] Fix database APIs (#3821)

* [AutoTVM] Fix database APIs

* Refactor the byte conversion
parent 19b8b3a4
...@@ -117,12 +117,12 @@ class RedisDatabase(Database): ...@@ -117,12 +117,12 @@ class RedisDatabase(Database):
self.db.set(key, value) self.db.set(key, value)
def get(self, key): def get(self, key):
return self.db.get(key) current = self.db.get(key)
return current.decode() if isinstance(current, bytes) else current
def load(self, inp, get_all=False): def load(self, inp, get_all=False):
current = self.get(measure_str_key(inp)) current = self.get(measure_str_key(inp))
if current is not None: if current is not None:
current = str(current)
records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)] records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
results = [rec[1] for rec in records] results = [rec[1] for rec in records]
if get_all: if get_all:
...@@ -142,29 +142,31 @@ class RedisDatabase(Database): ...@@ -142,29 +142,31 @@ class RedisDatabase(Database):
def filter(self, func): def filter(self, func):
""" """
Dump all of the records for a particular target Dump all of the records that match the given rule
Parameters Parameters
---------- ----------
func: callable func: callable
The signature of the function is bool (MeasureInput, Array of MeasureResult) The signature of the function is (MeasureInput, [MeasureResult]) -> bool
Returns Returns
------- -------
list of records (inp, result) matching the target list of records in tuple (MeasureInput, MeasureResult) matching the rule
Examples Examples
-------- --------
get records for a target get records for a target
>>> db.filter(lambda inp, resulst: "cuda" in inp.target.keys) >>> db.filter(lambda inp, resulst: "cuda" in inp.target.keys)
get records with errors
>>> db.filter(lambda inp, results: any(r.error_no != 0 for r in results))
""" """
matched_records = list() matched_records = list()
# may consider filtering in iterator in the future # may consider filtering in iterator in the future
for key in self.db: for key in self.db.keys():
current = self.get(key) current = self.get(key)
try: try:
records = [decode(x) for x in current.spilt(RedisDatabase.MAGIC_SPLIT)] records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
except TypeError: # got a badly formatted/old format record except TypeError: # got a badly formatted/old format record
continue continue
inps, results = zip(*records) inps, results = zip(*records)
...@@ -190,8 +192,5 @@ class DummyDatabase(RedisDatabase): ...@@ -190,8 +192,5 @@ class DummyDatabase(RedisDatabase):
def set(self, key, value): def set(self, key, value):
self.db[key] = value self.db[key] = value
def get(self, key):
return self.db.get(key)
def flush(self): def flush(self):
self.db = {} self.db = {}
...@@ -99,8 +99,20 @@ def test_db_latest_all(): ...@@ -99,8 +99,20 @@ def test_db_latest_all():
assert encode(inp1, load4[1]) == encode(inp1, res2) assert encode(inp1, load4[1]) == encode(inp1, res2)
assert encode(inp1, load4[2]) == encode(inp1, res3) assert encode(inp1, load4[2]) == encode(inp1, res3)
def test_db_filter():
logging.info("test db filter ...")
records = get_sample_records(5)
_db = database.DummyDatabase()
_db.flush()
for inp, result in records:
_db.save(inp, result)
records = _db.filter(lambda inp, ress: any(r.costs[0] <= 2 for r in ress))
assert len(records) == 2
if __name__ == '__main__': if __name__ == '__main__':
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
test_save_load() test_save_load()
test_db_hash() test_db_hash()
test_db_latest_all() test_db_latest_all()
test_db_filter()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment