Commit f3abb3d8 by Neo Chien Committed by Tianqi Chen

[TVM][AutoTVM] cast filepath arguments to string (#3968)

parent de123760
...@@ -41,6 +41,7 @@ from .space import FallbackConfigEntity ...@@ -41,6 +41,7 @@ from .space import FallbackConfigEntity
logger = logging.getLogger('autotvm') logger = logging.getLogger('autotvm')
class DispatchContext(object): class DispatchContext(object):
""" """
Base class of dispatch context. Base class of dispatch context.
...@@ -281,8 +282,12 @@ class ApplyHistoryBest(DispatchContext): ...@@ -281,8 +282,12 @@ class ApplyHistoryBest(DispatchContext):
Each row of this file is an encoded record pair. Each row of this file is an encoded record pair.
Otherwise, it is an iterator. Otherwise, it is an iterator.
""" """
from pathlib import Path
from ..record import load_from_file from ..record import load_from_file
if isinstance(records, Path):
records = str(records)
if isinstance(records, str): if isinstance(records, str):
records = load_from_file(records) records = load_from_file(records)
if not records: if not records:
...@@ -404,8 +409,10 @@ class FallbackContext(DispatchContext): ...@@ -404,8 +409,10 @@ class FallbackContext(DispatchContext):
key = (str(target), workload) key = (str(target), workload)
self.memory[key] = cfg self.memory[key] = cfg
DispatchContext.current = FallbackContext() DispatchContext.current = FallbackContext()
def clear_fallback_cache(target, workload): def clear_fallback_cache(target, workload):
"""Clear fallback cache. Pass the same argument as _query_inside to this function """Clear fallback cache. Pass the same argument as _query_inside to this function
to clean the cache. to clean the cache.
...@@ -426,6 +433,7 @@ def clear_fallback_cache(target, workload): ...@@ -426,6 +433,7 @@ def clear_fallback_cache(target, workload):
context = context._old_ctx context = context._old_ctx
context.clear_cache(target, workload) context.clear_cache(target, workload)
class ApplyGraphBest(DispatchContext): class ApplyGraphBest(DispatchContext):
"""Load the graph level tuning optimal schedules. """Load the graph level tuning optimal schedules.
......
...@@ -26,6 +26,7 @@ from .. import record ...@@ -26,6 +26,7 @@ from .. import record
logger = logging.getLogger('autotvm') logger = logging.getLogger('autotvm')
def log_to_file(file_out, protocol='json'): def log_to_file(file_out, protocol='json'):
"""Log the tuning records into file. """Log the tuning records into file.
The rows of the log are stored in the format of autotvm.record.encode. The rows of the log are stored in the format of autotvm.record.encode.
...@@ -51,6 +52,11 @@ def log_to_file(file_out, protocol='json'): ...@@ -51,6 +52,11 @@ def log_to_file(file_out, protocol='json'):
else: else:
for inp, result in zip(inputs, results): for inp, result in zip(inputs, results):
file_out.write(record.encode(inp, result, protocol) + "\n") file_out.write(record.encode(inp, result, protocol) + "\n")
from pathlib import Path
if isinstance(file_out, Path):
file_out = str(file_out)
return _callback return _callback
......
...@@ -107,6 +107,10 @@ class Module(ModuleBase): ...@@ -107,6 +107,10 @@ class Module(ModuleBase):
kwargs : dict, optional kwargs : dict, optional
Additional arguments passed to fcompile Additional arguments passed to fcompile
""" """
from pathlib import Path
if isinstance(file_name, Path):
file_name = str(file_name)
if self.type_key == "stackvm": if self.type_key == "stackvm":
if not file_name.endswith(".stackvm"): if not file_name.endswith(".stackvm"):
raise ValueError("Module[%s]: can only be saved as stackvm format." raise ValueError("Module[%s]: can only be saved as stackvm format."
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment