Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
f3abb3d8
Commit
f3abb3d8
authored
Sep 18, 2019
by
Neo Chien
Committed by
Tianqi Chen
Sep 18, 2019
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[TVM][AutoTVM] cast filepath arguments to string (#3968)
parent
de123760
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
0 deletions
+18
-0
python/tvm/autotvm/task/dispatcher.py
+8
-0
python/tvm/autotvm/tuner/callback.py
+6
-0
python/tvm/module.py
+4
-0
No files found.
python/tvm/autotvm/task/dispatcher.py
View file @
f3abb3d8
...
@@ -41,6 +41,7 @@ from .space import FallbackConfigEntity
...
@@ -41,6 +41,7 @@ from .space import FallbackConfigEntity
logger
=
logging
.
getLogger
(
'autotvm'
)
logger
=
logging
.
getLogger
(
'autotvm'
)
class
DispatchContext
(
object
):
class
DispatchContext
(
object
):
"""
"""
Base class of dispatch context.
Base class of dispatch context.
...
@@ -281,8 +282,12 @@ class ApplyHistoryBest(DispatchContext):
...
@@ -281,8 +282,12 @@ class ApplyHistoryBest(DispatchContext):
Each row of this file is an encoded record pair.
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
Otherwise, it is an iterator.
"""
"""
from
pathlib
import
Path
from
..record
import
load_from_file
from
..record
import
load_from_file
if
isinstance
(
records
,
Path
):
records
=
str
(
records
)
if
isinstance
(
records
,
str
):
if
isinstance
(
records
,
str
):
records
=
load_from_file
(
records
)
records
=
load_from_file
(
records
)
if
not
records
:
if
not
records
:
...
@@ -404,8 +409,10 @@ class FallbackContext(DispatchContext):
...
@@ -404,8 +409,10 @@ class FallbackContext(DispatchContext):
key
=
(
str
(
target
),
workload
)
key
=
(
str
(
target
),
workload
)
self
.
memory
[
key
]
=
cfg
self
.
memory
[
key
]
=
cfg
DispatchContext
.
current
=
FallbackContext
()
DispatchContext
.
current
=
FallbackContext
()
def
clear_fallback_cache
(
target
,
workload
):
def
clear_fallback_cache
(
target
,
workload
):
"""Clear fallback cache. Pass the same argument as _query_inside to this function
"""Clear fallback cache. Pass the same argument as _query_inside to this function
to clean the cache.
to clean the cache.
...
@@ -426,6 +433,7 @@ def clear_fallback_cache(target, workload):
...
@@ -426,6 +433,7 @@ def clear_fallback_cache(target, workload):
context
=
context
.
_old_ctx
context
=
context
.
_old_ctx
context
.
clear_cache
(
target
,
workload
)
context
.
clear_cache
(
target
,
workload
)
class
ApplyGraphBest
(
DispatchContext
):
class
ApplyGraphBest
(
DispatchContext
):
"""Load the graph level tuning optimal schedules.
"""Load the graph level tuning optimal schedules.
...
...
python/tvm/autotvm/tuner/callback.py
View file @
f3abb3d8
...
@@ -26,6 +26,7 @@ from .. import record
...
@@ -26,6 +26,7 @@ from .. import record
logger
=
logging
.
getLogger
(
'autotvm'
)
logger
=
logging
.
getLogger
(
'autotvm'
)
def
log_to_file
(
file_out
,
protocol
=
'json'
):
def
log_to_file
(
file_out
,
protocol
=
'json'
):
"""Log the tuning records into file.
"""Log the tuning records into file.
The rows of the log are stored in the format of autotvm.record.encode.
The rows of the log are stored in the format of autotvm.record.encode.
...
@@ -51,6 +52,11 @@ def log_to_file(file_out, protocol='json'):
...
@@ -51,6 +52,11 @@ def log_to_file(file_out, protocol='json'):
else
:
else
:
for
inp
,
result
in
zip
(
inputs
,
results
):
for
inp
,
result
in
zip
(
inputs
,
results
):
file_out
.
write
(
record
.
encode
(
inp
,
result
,
protocol
)
+
"
\n
"
)
file_out
.
write
(
record
.
encode
(
inp
,
result
,
protocol
)
+
"
\n
"
)
from
pathlib
import
Path
if
isinstance
(
file_out
,
Path
):
file_out
=
str
(
file_out
)
return
_callback
return
_callback
...
...
python/tvm/module.py
View file @
f3abb3d8
...
@@ -107,6 +107,10 @@ class Module(ModuleBase):
...
@@ -107,6 +107,10 @@ class Module(ModuleBase):
kwargs : dict, optional
kwargs : dict, optional
Additional arguments passed to fcompile
Additional arguments passed to fcompile
"""
"""
from
pathlib
import
Path
if
isinstance
(
file_name
,
Path
):
file_name
=
str
(
file_name
)
if
self
.
type_key
==
"stackvm"
:
if
self
.
type_key
==
"stackvm"
:
if
not
file_name
.
endswith
(
".stackvm"
):
if
not
file_name
.
endswith
(
".stackvm"
):
raise
ValueError
(
"Module[
%
s]: can only be saved as stackvm format."
raise
ValueError
(
"Module[
%
s]: can only be saved as stackvm format."
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment