Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
136061dc
Commit
136061dc
authored
Aug 03, 2018
by
Lianmin Zheng
Committed by
Tianqi Chen
Aug 03, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[AUTOTVM] Improve tutorial and logging (#1544)
parent
33606741
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
181 additions
and
97 deletions
+181
-97
python/tvm/autotvm/measure/__init__.py
+1
-1
python/tvm/autotvm/measure/measure_methods.py
+44
-4
python/tvm/autotvm/record.py
+7
-6
python/tvm/autotvm/task/dispatcher.py
+5
-1
python/tvm/autotvm/tophub.py
+4
-2
python/tvm/autotvm/tuner/callback.py
+6
-3
python/tvm/autotvm/tuner/sa_model_optimizer.py
+6
-4
python/tvm/autotvm/tuner/tuner.py
+16
-6
python/tvm/autotvm/tuner/xgboost_cost_model.py
+7
-5
python/tvm/autotvm/util.py
+4
-3
python/tvm/rpc/base.py
+5
-6
python/tvm/rpc/proxy.py
+3
-2
python/tvm/rpc/server.py
+19
-31
python/tvm/rpc/tracker.py
+11
-9
tutorials/autotvm/tune_conv2d_cuda.py
+2
-1
tutorials/autotvm/tune_nnvm_arm.py
+38
-11
tutorials/autotvm/tune_simple_template.py
+3
-2
No files found.
python/tvm/autotvm/measure/__init__.py
View file @
136061dc
"""Distributed executor infrastructure to scale up the tuning"""
"""Distributed executor infrastructure to scale up the tuning"""
from
.measure
import
MeasureInput
,
MeasureResult
,
MeasureErrorNo
,
measure_option
from
.measure
import
MeasureInput
,
MeasureResult
,
MeasureErrorNo
,
measure_option
from
.measure_methods
import
request_remote
,
create_measure_batch
,
use_rpc
from
.measure_methods
import
request_remote
,
c
heck_remote
,
c
reate_measure_batch
,
use_rpc
from
.local_executor
import
LocalExecutor
from
.local_executor
import
LocalExecutor
from
.executor
import
Future
,
Executor
from
.executor
import
Future
,
Executor
python/tvm/autotvm/measure/measure_methods.py
View file @
136061dc
...
@@ -9,6 +9,7 @@ import logging
...
@@ -9,6 +9,7 @@ import logging
import
os
import
os
import
time
import
time
from
random
import
getrandbits
from
random
import
getrandbits
import
threading
import
numpy
as
np
import
numpy
as
np
...
@@ -23,6 +24,7 @@ from ..task.space import InstantiationError
...
@@ -23,6 +24,7 @@ from ..task.space import InstantiationError
from
.measure
import
MeasureResult
,
MeasureErrorNo
from
.measure
import
MeasureResult
,
MeasureErrorNo
from
.local_executor
import
LocalExecutor
from
.local_executor
import
LocalExecutor
logger
=
logging
.
getLogger
(
'autotvm'
)
class
HashMismatchError
(
ValueError
):
class
HashMismatchError
(
ValueError
):
"""Raised when the code hash of a submitted config doesn't match that on the
"""Raised when the code hash of a submitted config doesn't match that on the
...
@@ -42,9 +44,9 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
...
@@ -42,9 +44,9 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
If is none, will use environment variable "TVM_TRACKER_HOST"
If is none, will use environment variable "TVM_TRACKER_HOST"
and "TVM_TRACKER_PORT"
and "TVM_TRACKER_PORT"
priority: int, optional
priority: int, optional
priority of this request, larger is more prior
The
priority of this request, larger is more prior
timeout: float, optional
timeout: float, optional
timeout of this session (units: seconds)
The
timeout of this session (units: seconds)
Returns
Returns
------
------
...
@@ -63,6 +65,33 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
...
@@ -63,6 +65,33 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
session_timeout
=
timeout
)
session_timeout
=
timeout
)
return
remote
return
remote
def
check_remote
(
target
,
device_key
,
tracker_addr
=
None
,
priority
=
2
,
timeout
=
10
):
"""
Check the availability of a remote device
Parameters
----------
target: Target
The wanted compilation target
device_key: string
device key of registered device in tracker
tracker_addr: Tuple(string, int), optional
The address of rpc tracker in (host, port) format.
If is none, will use environment variable "TVM_TRACKER_HOST"
and "TVM_TRACKER_PORT"
priority: int, optional
The priority of this request, larger is more prior
timeout: float, optional
The timeout of this check (units: seconds).
If time is out, a RuntimerError will be raised.
"""
def
_check
():
remote
=
request_remote
(
device_key
,
tracker_addr
,
priority
)
remote
.
context
(
str
(
target
))
t
=
threading
.
Thread
(
target
=
_check
,)
t
.
start
()
t
.
join
(
timeout
)
return
not
t
.
is_alive
()
def
create_measure_batch
(
task
,
option
):
def
create_measure_batch
(
task
,
option
):
"""Get a standard measure_batch function.
"""Get a standard measure_batch function.
...
@@ -115,6 +144,17 @@ def create_measure_batch(task, option):
...
@@ -115,6 +144,17 @@ def create_measure_batch(task, option):
build_func
=
default_build_func
build_func
=
default_build_func
build_kwargs
[
'use_ndk'
]
=
True
build_kwargs
[
'use_ndk'
]
=
True
# check the availability of remote devices
if
hasattr
(
measure_func
,
'rpc_info'
):
rpc_info
=
measure_func
.
rpc_info
if
check_remote
(
task
.
target
,
rpc_info
[
'key'
],
(
rpc_info
[
'host'
],
rpc_info
[
'port'
])):
logger
.
info
(
"Get devices for measurement successfully!"
)
else
:
raise
RuntimeError
(
"Cannot get remote devices from the tracker. "
"Please check the status of tracker by "
"'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
"and make sure you have free devices on the queue status."
)
# add device info of cuda and opencl target
# add device info of cuda and opencl target
if
(
'cuda'
in
task
.
target
.
keys
or
'opencl'
in
task
.
target
.
keys
)
\
if
(
'cuda'
in
task
.
target
.
keys
or
'opencl'
in
task
.
target
.
keys
)
\
and
hasattr
(
measure_func
,
'rpc_info'
):
and
hasattr
(
measure_func
,
'rpc_info'
):
...
@@ -313,7 +353,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
...
@@ -313,7 +353,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
continue
continue
except
InstantiationError
as
e
:
except
InstantiationError
as
e
:
tstamp
=
time
.
time
()
tstamp
=
time
.
time
()
res_pack
.
append
(
MeasureResult
((
e
,),
res_pack
.
append
(
MeasureResult
((
InstantiationError
(
str
(
e
))
,),
MeasureErrorNo
.
INSTANTIATION_ERROR
,
MeasureErrorNo
.
INSTANTIATION_ERROR
,
tstamp
-
tic
,
tstamp
))
tstamp
-
tic
,
tstamp
))
continue
continue
...
@@ -346,7 +386,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
...
@@ -346,7 +386,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
if
ref_output
:
if
ref_output
:
for
expected
,
real
in
zip
(
ref_output
,
args
):
for
expected
,
real
in
zip
(
ref_output
,
args
):
if
not
np
.
allclose
(
expected
,
real
.
asnumpy
(),
rtol
=
1e-4
):
if
not
np
.
allclose
(
expected
,
real
.
asnumpy
(),
rtol
=
1e-4
):
logg
ing
.
warning
(
"Wrong Answer!"
)
logg
er
.
warning
(
"Wrong Answer!"
)
errno
=
MeasureErrorNo
.
WRONG_ANSWER
errno
=
MeasureErrorNo
.
WRONG_ANSWER
except
TVMError
as
exc
:
except
TVMError
as
exc
:
msg
=
str
(
exc
)
msg
=
str
(
exc
)
...
...
python/tvm/autotvm/record.py
View file @
136061dc
...
@@ -18,6 +18,7 @@ from .task import ConfigEntity, ApplyHistoryBest
...
@@ -18,6 +18,7 @@ from .task import ConfigEntity, ApplyHistoryBest
from
.measure
import
MeasureInput
,
MeasureResult
from
.measure
import
MeasureInput
,
MeasureResult
AUTOTVM_LOG_VERSION
=
0.1
AUTOTVM_LOG_VERSION
=
0.1
logger
=
logging
.
getLogger
(
'autotvm'
)
try
:
# convert unicode to str for python2
try
:
# convert unicode to str for python2
_unicode
=
unicode
_unicode
=
unicode
...
@@ -181,10 +182,10 @@ def split_workload(in_file, clean=True):
...
@@ -181,10 +182,10 @@ def split_workload(in_file, clean=True):
tic
=
time
.
time
()
tic
=
time
.
time
()
lines
=
list
(
open
(
in_file
)
.
readlines
())
lines
=
list
(
open
(
in_file
)
.
readlines
())
logg
ing
.
info
(
"start converting..."
)
logg
er
.
info
(
"start converting..."
)
pool
=
multiprocessing
.
Pool
()
pool
=
multiprocessing
.
Pool
()
lines
=
pool
.
map
(
decode
,
lines
)
lines
=
pool
.
map
(
decode
,
lines
)
logg
ing
.
info
(
"map done
%.2
f"
,
time
.
time
()
-
tic
)
logg
er
.
info
(
"map done
%.2
f"
,
time
.
time
()
-
tic
)
wkl_dict
=
OrderedDict
()
wkl_dict
=
OrderedDict
()
for
inp
,
res
in
lines
:
for
inp
,
res
in
lines
:
...
@@ -206,13 +207,13 @@ def split_workload(in_file, clean=True):
...
@@ -206,13 +207,13 @@ def split_workload(in_file, clean=True):
cleaned
.
append
([
inp
,
res
])
cleaned
.
append
([
inp
,
res
])
# write to file
# write to file
logg
ing
.
info
(
"Key:
%
s
\t
Valid:
%
d
\t
Dup:
%
d
\t
"
,
k
,
len
(
cleaned
),
len
(
v
)
-
len
(
cleaned
))
logg
er
.
info
(
"Key:
%
s
\t
Valid:
%
d
\t
Dup:
%
d
\t
"
,
k
,
len
(
cleaned
),
len
(
v
)
-
len
(
cleaned
))
with
open
(
args
.
i
+
".
%03
d.wkl"
%
i
,
'w'
)
as
fout
:
with
open
(
args
.
i
+
".
%03
d.wkl"
%
i
,
'w'
)
as
fout
:
for
inp
,
res
in
cleaned
:
for
inp
,
res
in
cleaned
:
fout
.
write
(
encode
(
inp
,
res
)
+
'
\n
'
)
fout
.
write
(
encode
(
inp
,
res
)
+
'
\n
'
)
else
:
else
:
for
i
,
(
k
,
v
)
in
enumerate
(
wkl_dict
.
items
()):
for
i
,
(
k
,
v
)
in
enumerate
(
wkl_dict
.
items
()):
logg
ing
.
info
(
"Key:
%
s
\t
Num:
%
d"
,
k
,
len
(
v
))
logg
er
.
info
(
"Key:
%
s
\t
Num:
%
d"
,
k
,
len
(
v
))
with
open
(
args
.
i
+
".
%03
d.wkl"
%
i
,
'w'
)
as
fout
:
with
open
(
args
.
i
+
".
%03
d.wkl"
%
i
,
'w'
)
as
fout
:
for
inp
,
res
in
v
:
for
inp
,
res
in
v
:
fout
.
write
(
encode
(
inp
,
res
)
+
'
\n
'
)
fout
.
write
(
encode
(
inp
,
res
)
+
'
\n
'
)
...
@@ -238,7 +239,7 @@ def pick_best(in_file, out_file):
...
@@ -238,7 +239,7 @@ def pick_best(in_file, out_file):
for
v
in
best_context
.
best_by_targetkey
.
values
():
for
v
in
best_context
.
best_by_targetkey
.
values
():
best_set
.
add
(
measure_str_key
(
v
[
0
]))
best_set
.
add
(
measure_str_key
(
v
[
0
]))
logg
ing
.
info
(
"Extract
%
d best records from the
%
s"
,
len
(
best_set
),
in_file
)
logg
er
.
info
(
"Extract
%
d best records from the
%
s"
,
len
(
best_set
),
in_file
)
fout
=
open
(
out_file
,
'w'
)
if
isinstance
(
out_file
,
str
)
else
out_file
fout
=
open
(
out_file
,
'w'
)
if
isinstance
(
out_file
,
str
)
else
out_file
for
inp
,
res
in
load_from_file
(
in_file
):
for
inp
,
res
in
load_from_file
(
in_file
):
...
@@ -270,7 +271,7 @@ if __name__ == '__main__':
...
@@ -270,7 +271,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
"--code"
,
action
=
'store_true'
)
parser
.
add_argument
(
"--code"
,
action
=
'store_true'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
logg
ing
.
basicConfig
(
level
=
logging
.
INFO
)
logg
er
.
basicConfig
(
level
=
logger
.
INFO
)
if
args
.
mode
==
'pick'
:
if
args
.
mode
==
'pick'
:
args
.
o
=
args
.
o
or
args
.
i
+
".best.log"
args
.
o
=
args
.
o
or
args
.
i
+
".best.log"
...
...
python/tvm/autotvm/task/dispatcher.py
View file @
136061dc
...
@@ -10,6 +10,8 @@ of the DispatchContext base class.
...
@@ -10,6 +10,8 @@ of the DispatchContext base class.
- During search, we can use it to pass the current proposal from tuner.
- During search, we can use it to pass the current proposal from tuner.
- During evaluation, we can use it to set pick the best policy.
- During evaluation, we can use it to set pick the best policy.
"""
"""
# pylint: disable=invalid-name
from
__future__
import
absolute_import
as
_abs
from
__future__
import
absolute_import
as
_abs
import
logging
import
logging
...
@@ -19,6 +21,8 @@ import numpy as np
...
@@ -19,6 +21,8 @@ import numpy as np
from
tvm
import
target
as
_target
from
tvm
import
target
as
_target
logger
=
logging
.
getLogger
(
'autotvm'
)
class
DispatchContext
(
object
):
class
DispatchContext
(
object
):
"""
"""
Base class of dispatch context.
Base class of dispatch context.
...
@@ -216,7 +220,7 @@ class ApplyHistoryBest(DispatchContext):
...
@@ -216,7 +220,7 @@ class ApplyHistoryBest(DispatchContext):
best_by_model
[
key
]
=
(
inp
,
res
)
best_by_model
[
key
]
=
(
inp
,
res
)
break
break
logg
ing
.
debug
(
"Finish loading
%
d records"
,
counter
)
logg
er
.
debug
(
"Finish loading
%
d records"
,
counter
)
def
query
(
self
,
target
,
workload
):
def
query
(
self
,
target
,
workload
):
if
target
is
None
:
if
target
is
None
:
...
...
python/tvm/autotvm/tophub.py
View file @
136061dc
...
@@ -4,6 +4,7 @@ To get the best performance, we typically need auto-tuning for the specific devi
...
@@ -4,6 +4,7 @@ To get the best performance, we typically need auto-tuning for the specific devi
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
TVM will download these parameters for you when you create the target for the first time.
TVM will download these parameters for you when you create the target for the first time.
"""
"""
# pylint: disable=invalid-name
import
logging
import
logging
import
os
import
os
...
@@ -16,6 +17,7 @@ from ..contrib.download import download
...
@@ -16,6 +17,7 @@ from ..contrib.download import download
AUTOTVM_TOPHUB_ROOT_PATH
=
os
.
path
.
join
(
os
.
path
.
expanduser
(
'~'
),
".tvm"
,
"tophub"
)
AUTOTVM_TOPHUB_ROOT_PATH
=
os
.
path
.
join
(
os
.
path
.
expanduser
(
'~'
),
".tvm"
,
"tophub"
)
logger
=
logging
.
getLogger
(
'autotvm'
)
def
_alias
(
name
):
def
_alias
(
name
):
"""convert alias for some packages"""
"""convert alias for some packages"""
...
@@ -79,7 +81,7 @@ def download_package(backend):
...
@@ -79,7 +81,7 @@ def download_package(backend):
os
.
mkdir
(
path
)
os
.
mkdir
(
path
)
backend
=
_alias
(
backend
)
backend
=
_alias
(
backend
)
logg
ing
.
info
(
"Download pre-tuned parameters for
%
s"
,
backend
)
logg
er
.
info
(
"Download pre-tuned parameters for
%
s"
,
backend
)
download
(
"https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/
%
s.log"
%
backend
,
download
(
"https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/
%
s.log"
%
backend
,
os
.
path
.
join
(
rootpath
,
backend
+
".log"
),
True
,
verbose
=
0
)
os
.
path
.
join
(
rootpath
,
backend
+
".log"
),
True
,
verbose
=
0
)
...
@@ -110,7 +112,7 @@ def list_packages():
...
@@ -110,7 +112,7 @@ def list_packages():
"""
"""
path
=
tempdir
()
path
=
tempdir
()
filename
=
path
.
relpath
(
"info.json"
)
filename
=
path
.
relpath
(
"info.json"
)
logg
ing
.
info
(
"Download meta info for pre-tuned parameters"
)
logg
er
.
info
(
"Download meta info for pre-tuned parameters"
)
download
(
"https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/info.json"
,
download
(
"https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/info.json"
,
filename
,
True
,
verbose
=
0
)
filename
,
True
,
verbose
=
0
)
...
...
python/tvm/autotvm/tuner/callback.py
View file @
136061dc
...
@@ -2,11 +2,13 @@
...
@@ -2,11 +2,13 @@
"""Namespace of callback utilities of AutoTVM"""
"""Namespace of callback utilities of AutoTVM"""
import
sys
import
sys
import
time
import
time
import
logging
import
numpy
as
np
import
numpy
as
np
from
..
import
record
from
..
import
record
logger
=
logging
.
getLogger
(
'autotvm'
)
def
log_to_file
(
file_out
,
protocol
=
'json'
):
def
log_to_file
(
file_out
,
protocol
=
'json'
):
"""Log the tuning records into file.
"""Log the tuning records into file.
...
@@ -90,7 +92,7 @@ def progress_bar(total, prefix=''):
...
@@ -90,7 +92,7 @@ def progress_bar(total, prefix=''):
prefix: str
prefix: str
The prefix of output message
The prefix of output message
"""
"""
class
_Context
:
class
_Context
(
object
)
:
"""Context to store local variables"""
"""Context to store local variables"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
best_flops
=
0
self
.
best_flops
=
0
...
@@ -112,11 +114,12 @@ def progress_bar(total, prefix=''):
...
@@ -112,11 +114,12 @@ def progress_bar(total, prefix=''):
if
res
.
error_no
==
0
:
if
res
.
error_no
==
0
:
flops
=
inp
.
task
.
flop
/
np
.
mean
(
res
.
costs
)
flops
=
inp
.
task
.
flop
/
np
.
mean
(
res
.
costs
)
if
logger
.
level
<
logging
.
DEBUG
:
# only print progress bar in non-debug mode
ctx
.
cur_flops
=
flops
ctx
.
cur_flops
=
flops
ctx
.
best_flops
=
tuner
.
best_flops
ctx
.
best_flops
=
tuner
.
best_flops
sys
.
stdout
.
write
(
'
\r
%
s Current/Best:
%7.2
f/
%7.2
f GFLOPS | Progress: (
%
d/
%
d) '
sys
.
stdout
.
write
(
'
%
s Current/Best:
%7.2
f/
%7.2
f GFLOPS | Progress: (
%
d/
%
d) '
'|
%.2
f s
'
%
'|
%.2
f s
\r
'
%
(
prefix
,
ctx
.
cur_flops
/
1e9
,
ctx
.
best_flops
/
1e9
,
ctx
.
ct
,
ctx
.
total
,
(
prefix
,
ctx
.
cur_flops
/
1e9
,
ctx
.
best_flops
/
1e9
,
ctx
.
ct
,
ctx
.
total
,
time
.
time
()
-
tic
))
time
.
time
()
-
tic
))
sys
.
stdout
.
flush
()
sys
.
stdout
.
flush
()
...
...
python/tvm/autotvm/tuner/sa_model_optimizer.py
View file @
136061dc
# pylint: disable=consider-using-enumerate
# pylint: disable=consider-using-enumerate
, invalid-name
"""
"""
Cost model optimizer based on simulated annealing
Cost model optimizer based on simulated annealing
"""
"""
...
@@ -12,6 +12,8 @@ import numpy as np
...
@@ -12,6 +12,8 @@ import numpy as np
from
..util
import
sample_ints
from
..util
import
sample_ints
from
.model_based_tuner
import
ModelOptimizer
,
knob2point
,
point2knob
from
.model_based_tuner
import
ModelOptimizer
,
knob2point
,
point2knob
logger
=
logging
.
getLogger
(
'autotvm'
)
class
SimulatedAnnealingOptimizer
(
ModelOptimizer
):
class
SimulatedAnnealingOptimizer
(
ModelOptimizer
):
"""parallel simulated annealing optimization algorithm
"""parallel simulated annealing optimization algorithm
...
@@ -103,16 +105,16 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
...
@@ -103,16 +105,16 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
if
log_interval
and
k
%
log_interval
==
0
:
if
log_interval
and
k
%
log_interval
==
0
:
t_str
=
"
%.2
f"
%
t
t_str
=
"
%.2
f"
%
t
logg
ing
.
debug
(
"SA iter:
%
d
\t
last_update:
%
d
\t
max-0:
%.2
f
\t
max-1:
%.2
f
\t
temp:
%
s
\t
"
logg
er
.
debug
(
"SA iter:
%
d
\t
last_update:
%
d
\t
max-0:
%.2
f
\t
max-1:
%.2
f
\t
temp:
%
s
\t
"
"elapsed:
%.2
f"
,
"elapsed:
%.2
f"
,
k
,
k_last_modify
,
heap_items
[
0
][
0
],
k
,
k_last_modify
,
heap_items
[
0
][
0
],
np
.
max
([
v
for
v
,
_
in
heap_items
]),
t_str
,
np
.
max
([
v
for
v
,
_
in
heap_items
]),
t_str
,
time
.
time
()
-
tic
)
time
.
time
()
-
tic
)
heap_items
.
sort
(
key
=
lambda
item
:
-
item
[
0
])
heap_items
.
sort
(
key
=
lambda
item
:
-
item
[
0
])
logg
ing
.
debug
(
"SA iter:
%
d
\t
last_update:
%
d
\t
max-0:
%.2
f
\t
max-1:
%.2
f
\t
elapsed:
%.2
f"
,
logg
er
.
debug
(
"SA iter:
%
d
\t
last_update:
%
d
\t
max-0:
%.2
f
\t
max-1:
%.2
f
\t
elapsed:
%.2
f"
,
k
,
k_last_modify
,
heap_items
[
-
1
][
0
],
heap_items
[
0
][
0
],
time
.
time
()
-
tic
)
k
,
k_last_modify
,
heap_items
[
-
1
][
0
],
heap_items
[
0
][
0
],
time
.
time
()
-
tic
)
logg
ing
.
debug
(
"SA Maximums:
%
s"
,
heap_items
)
logg
er
.
debug
(
"SA Maximums:
%
s"
,
heap_items
)
if
self
.
persistent
:
if
self
.
persistent
:
self
.
points
=
points
self
.
points
=
points
...
...
python/tvm/autotvm/tuner/tuner.py
View file @
136061dc
...
@@ -4,11 +4,12 @@ import logging
...
@@ -4,11 +4,12 @@ import logging
import
numpy
as
np
import
numpy
as
np
from
..measure
import
MeasureInput
from
..measure
import
MeasureInput
,
create_measure_batch
from
..measure
import
create_measure_batch
from
..env
import
GLOBAL_SCOPE
from
..env
import
GLOBAL_SCOPE
logger
=
logging
.
getLogger
(
'autotvm'
)
class
Tuner
(
object
):
class
Tuner
(
object
):
"""Base class for tuners
"""Base class for tuners
...
@@ -86,9 +87,10 @@ class Tuner(object):
...
@@ -86,9 +87,10 @@ class Tuner(object):
measure_batch
=
create_measure_batch
(
self
.
task
,
measure_option
)
measure_batch
=
create_measure_batch
(
self
.
task
,
measure_option
)
parallel_num
=
getattr
(
measure_batch
,
'parallel_num'
,
1
)
parallel_num
=
getattr
(
measure_batch
,
'parallel_num'
,
1
)
early_stopping
=
early_stopping
or
1e9
early_stopping
=
early_stopping
or
1e9
old_level
=
logger
.
level
GLOBAL_SCOPE
.
in_tuning
=
True
GLOBAL_SCOPE
.
in_tuning
=
True
i
=
0
i
=
error_ct
=
0
while
i
<
n_trial
:
while
i
<
n_trial
:
if
not
self
.
has_next
():
if
not
self
.
has_next
():
break
break
...
@@ -103,15 +105,18 @@ class Tuner(object):
...
@@ -103,15 +105,18 @@ class Tuner(object):
config
=
inp
.
config
config
=
inp
.
config
if
res
.
error_no
==
0
:
if
res
.
error_no
==
0
:
flops
=
inp
.
task
.
flop
/
np
.
mean
(
res
.
costs
)
flops
=
inp
.
task
.
flop
/
np
.
mean
(
res
.
costs
)
error_ct
=
0
else
:
else
:
flops
=
0
flops
=
0
error_ct
+=
1
if
flops
>
self
.
best_flops
:
if
flops
>
self
.
best_flops
:
self
.
best_flops
=
flops
self
.
best_flops
=
flops
self
.
best_config
=
config
self
.
best_config
=
config
self
.
best_measure_pair
=
(
inp
,
res
)
self
.
best_measure_pair
=
(
inp
,
res
)
self
.
best_iter
=
i
+
k
self
.
best_iter
=
i
+
k
logg
ing
.
debug
(
"No:
%
d
\t
GFLOPS:
%.2
f/
%.2
f
\t
result:
%
s
\t
%
s"
,
logg
er
.
debug
(
"No:
%
d
\t
GFLOPS:
%.2
f/
%.2
f
\t
result:
%
s
\t
%
s"
,
i
+
k
+
1
,
flops
/
1e9
,
self
.
best_flops
/
1e9
,
i
+
k
+
1
,
flops
/
1e9
,
self
.
best_flops
/
1e9
,
res
,
config
)
res
,
config
)
...
@@ -123,11 +128,16 @@ class Tuner(object):
...
@@ -123,11 +128,16 @@ class Tuner(object):
callback
(
self
,
inputs
,
results
)
callback
(
self
,
inputs
,
results
)
if
i
>
self
.
best_iter
+
early_stopping
:
if
i
>
self
.
best_iter
+
early_stopping
:
logg
ing
.
debug
(
"Early stopped. Best iter:
%
d."
,
self
.
best_iter
)
logg
er
.
debug
(
"Early stopped. Best iter:
%
d."
,
self
.
best_iter
)
break
break
GLOBAL_SCOPE
.
in_tuning
=
False
if
error_ct
>
50
:
logger
.
warning
(
"Too many errors happen in the tuning. Now is in debug mode"
)
logger
.
setLevel
(
logging
.
DEBUG
)
else
:
logger
.
setLevel
(
old_level
)
GLOBAL_SCOPE
.
in_tuning
=
False
del
measure_batch
del
measure_batch
def
reset
(
self
):
def
reset
(
self
):
...
...
python/tvm/autotvm/tuner/xgboost_cost_model.py
View file @
136061dc
...
@@ -16,6 +16,8 @@ from ..util import get_rank
...
@@ -16,6 +16,8 @@ from ..util import get_rank
from
.metric
import
max_curve
,
recall_curve
,
cover_curve
from
.metric
import
max_curve
,
recall_curve
,
cover_curve
from
.model_based_tuner
import
CostModel
,
FeatureCache
from
.model_based_tuner
import
CostModel
,
FeatureCache
logger
=
logging
.
getLogger
(
'autotvm'
)
class
XGBoostCostModel
(
CostModel
):
class
XGBoostCostModel
(
CostModel
):
"""XGBoost as cost model
"""XGBoost as cost model
...
@@ -163,7 +165,7 @@ class XGBoostCostModel(CostModel):
...
@@ -163,7 +165,7 @@ class XGBoostCostModel(CostModel):
],
],
verbose_eval
=
self
.
log_interval
)])
verbose_eval
=
self
.
log_interval
)])
logg
ing
.
debug
(
"XGB train:
%.2
f
\t
obs:
%
d
\t
error:
%
d
\t
n_cache:
%
d"
,
logg
er
.
debug
(
"XGB train:
%.2
f
\t
obs:
%
d
\t
error:
%
d
\t
n_cache:
%
d"
,
time
.
time
()
-
tic
,
len
(
xs
),
time
.
time
()
-
tic
,
len
(
xs
),
len
(
xs
)
-
np
.
sum
(
valid_index
),
len
(
xs
)
-
np
.
sum
(
valid_index
),
self
.
feature_cache
.
size
(
self
.
fea_type
))
self
.
feature_cache
.
size
(
self
.
fea_type
))
...
@@ -173,7 +175,7 @@ class XGBoostCostModel(CostModel):
...
@@ -173,7 +175,7 @@ class XGBoostCostModel(CostModel):
self
.
_reset_pool
()
self
.
_reset_pool
()
args
=
list
(
records
)
args
=
list
(
records
)
logg
ing
.
debug
(
"XGB load
%
d entries from history log file"
,
len
(
args
))
logg
er
.
debug
(
"XGB load
%
d entries from history log file"
,
len
(
args
))
if
self
.
fea_type
==
'itervar'
:
if
self
.
fea_type
==
'itervar'
:
feature_extract_func
=
_extract_itervar_feature_log
feature_extract_func
=
_extract_itervar_feature_log
...
@@ -208,7 +210,7 @@ class XGBoostCostModel(CostModel):
...
@@ -208,7 +210,7 @@ class XGBoostCostModel(CostModel):
],
],
verbose_eval
=
self
.
log_interval
)])
verbose_eval
=
self
.
log_interval
)])
logg
ing
.
debug
(
"XGB train:
%.2
f
\t
obs:
%
d"
,
time
.
time
()
-
tic
,
len
(
xs
))
logg
er
.
debug
(
"XGB train:
%.2
f
\t
obs:
%
d"
,
time
.
time
()
-
tic
,
len
(
xs
))
def
predict
(
self
,
xs
,
output_margin
=
False
):
def
predict
(
self
,
xs
,
output_margin
=
False
):
feas
=
self
.
_get_feature
(
xs
)
feas
=
self
.
_get_feature
(
xs
)
...
@@ -403,7 +405,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
...
@@ -403,7 +405,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
infos
.
append
(
"
%
s:
%.6
f"
%
(
item
[
0
],
item
[
1
]))
infos
.
append
(
"
%
s:
%.6
f"
%
(
item
[
0
],
item
[
1
]))
if
not
isinstance
(
verbose_eval
,
bool
)
and
verbose_eval
and
i
%
verbose_eval
==
0
:
if
not
isinstance
(
verbose_eval
,
bool
)
and
verbose_eval
and
i
%
verbose_eval
==
0
:
logg
ing
.
debug
(
"
\t
"
.
join
(
infos
))
logg
er
.
debug
(
"
\t
"
.
join
(
infos
))
if
log_file
:
if
log_file
:
with
open
(
log_file
,
"a"
)
as
fout
:
with
open
(
log_file
,
"a"
)
as
fout
:
fout
.
write
(
"
\t
"
.
join
(
infos
)
+
'
\n
'
)
fout
.
write
(
"
\t
"
.
join
(
infos
)
+
'
\n
'
)
...
@@ -435,7 +437,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
...
@@ -435,7 +437,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
elif
env
.
iteration
-
best_iteration
>=
stopping_rounds
:
elif
env
.
iteration
-
best_iteration
>=
stopping_rounds
:
best_msg
=
state
[
'best_msg'
]
best_msg
=
state
[
'best_msg'
]
if
verbose_eval
and
env
.
rank
==
0
:
if
verbose_eval
and
env
.
rank
==
0
:
logg
ing
.
debug
(
"XGB stopped. Best iteration:
%
s "
,
best_msg
)
logg
er
.
debug
(
"XGB stopped. Best iteration:
%
s "
,
best_msg
)
raise
EarlyStopException
(
best_iteration
)
raise
EarlyStopException
(
best_iteration
)
return
callback
return
callback
...
...
python/tvm/autotvm/util.py
View file @
136061dc
...
@@ -8,6 +8,7 @@ import numpy as np
...
@@ -8,6 +8,7 @@ import numpy as np
from
..
import
expr
,
ir_pass
from
..
import
expr
,
ir_pass
logger
=
logging
.
getLogger
(
'autotvm'
)
class
EmptyContext
(
object
):
class
EmptyContext
(
object
):
"""An empty context"""
"""An empty context"""
...
@@ -92,15 +93,15 @@ def pool_map(func, args, batch_size, verbose=False, pool=None):
...
@@ -92,15 +93,15 @@ def pool_map(func, args, batch_size, verbose=False, pool=None):
tic
=
time
.
time
()
tic
=
time
.
time
()
local_pool
=
pool
or
multiprocessing
.
Pool
()
local_pool
=
pool
or
multiprocessing
.
Pool
()
if
verbose
:
if
verbose
:
logg
ing
.
info
(
"mapping begin"
)
logg
er
.
info
(
"mapping begin"
)
for
i
in
range
(
0
,
len
(
args
),
batch_size
):
for
i
in
range
(
0
,
len
(
args
),
batch_size
):
if
verbose
:
if
verbose
:
logg
ing
.
info
(
"mapping
%
d/
%
d elapsed
%.2
f"
,
i
,
len
(
args
),
logg
er
.
info
(
"mapping
%
d/
%
d elapsed
%.2
f"
,
i
,
len
(
args
),
time
.
time
()
-
tic
)
time
.
time
()
-
tic
)
tmp
=
np
.
array
(
local_pool
.
map
(
func
,
args
[
i
:
i
+
batch_size
]))
tmp
=
np
.
array
(
local_pool
.
map
(
func
,
args
[
i
:
i
+
batch_size
]))
ret
=
tmp
if
ret
is
None
else
np
.
concatenate
((
ret
,
tmp
))
ret
=
tmp
if
ret
is
None
else
np
.
concatenate
((
ret
,
tmp
))
if
verbose
:
if
verbose
:
logg
ing
.
info
(
"mapping done"
)
logg
er
.
info
(
"mapping done"
)
if
not
pool
:
if
not
pool
:
local_pool
.
close
()
local_pool
.
close
()
return
ret
return
ret
...
...
python/tvm/rpc/base.py
View file @
136061dc
"""Base definitions for RPC."""
"""Base definitions for RPC."""
# pylint: disable=invalid-name
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
socket
import
socket
...
@@ -23,6 +25,7 @@ RPC_CODE_DUPLICATE = RPC_MAGIC + 1
...
@@ -23,6 +25,7 @@ RPC_CODE_DUPLICATE = RPC_MAGIC + 1
# cannot found matched key in server
# cannot found matched key in server
RPC_CODE_MISMATCH
=
RPC_MAGIC
+
2
RPC_CODE_MISMATCH
=
RPC_MAGIC
+
2
logger
=
logging
.
getLogger
(
'RPCServer'
)
class
TrackerCode
(
object
):
class
TrackerCode
(
object
):
"""Enumeration code for the RPC tracker"""
"""Enumeration code for the RPC tracker"""
...
@@ -120,7 +123,7 @@ def random_key(prefix, cmap=None):
...
@@ -120,7 +123,7 @@ def random_key(prefix, cmap=None):
return
prefix
+
str
(
random
.
random
())
return
prefix
+
str
(
random
.
random
())
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
,
silent
=
False
):
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
):
"""Connect to a TPC address with retry
"""Connect to a TPC address with retry
This function is only reliable to short period of server restart.
This function is only reliable to short period of server restart.
...
@@ -135,9 +138,6 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
...
@@ -135,9 +138,6 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
retry_period : float
retry_period : float
Number of seconds before we retry again.
Number of seconds before we retry again.
silent: bool
whether run in silent mode
"""
"""
tstart
=
time
.
time
()
tstart
=
time
.
time
()
while
True
:
while
True
:
...
@@ -152,8 +152,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
...
@@ -152,8 +152,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
if
period
>
timeout
:
if
period
>
timeout
:
raise
RuntimeError
(
raise
RuntimeError
(
"Failed to connect to server
%
s"
%
str
(
addr
))
"Failed to connect to server
%
s"
%
str
(
addr
))
if
not
silent
:
logger
.
warning
(
"Cannot connect to tracker
%
s, retry in
%
g secs..."
,
logging
.
info
(
"Cannot connect to tracker
%
s, retry in
%
g secs..."
,
str
(
addr
),
retry_period
)
str
(
addr
),
retry_period
)
time
.
sleep
(
retry_period
)
time
.
sleep
(
retry_period
)
...
...
python/tvm/rpc/proxy.py
View file @
136061dc
...
@@ -23,7 +23,8 @@ try:
...
@@ -23,7 +23,8 @@ try:
from
tornado
import
ioloop
from
tornado
import
ioloop
from
.
import
tornado_util
from
.
import
tornado_util
except
ImportError
as
error_msg
:
except
ImportError
as
error_msg
:
raise
ImportError
(
"RPCProxy module requires tornado package
%
s"
%
error_msg
)
raise
ImportError
(
"RPCProxy module requires tornado package
%
s. Try 'pip install tornado'."
%
error_msg
)
from
.
import
base
from
.
import
base
from
.base
import
TrackerCode
from
.base
import
TrackerCode
...
@@ -540,7 +541,7 @@ def websocket_proxy_server(url, key=""):
...
@@ -540,7 +541,7 @@ def websocket_proxy_server(url, key=""):
def
_connect
(
key
):
def
_connect
(
key
):
conn
=
yield
websocket
.
websocket_connect
(
url
)
conn
=
yield
websocket
.
websocket_connect
(
url
)
on_message
=
create_on_message
(
conn
)
on_message
=
create_on_message
(
conn
)
temp
=
_server_env
(
None
,
None
)
temp
=
_server_env
(
None
)
# Start connecton
# Start connecton
conn
.
write_message
(
struct
.
pack
(
'<i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
conn
.
write_message
(
struct
.
pack
(
'<i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
key
=
"server:"
+
key
key
=
"server:"
+
key
...
...
python/tvm/rpc/server.py
View file @
136061dc
...
@@ -8,6 +8,8 @@ Server is TCP based with the following protocol:
...
@@ -8,6 +8,8 @@ Server is TCP based with the following protocol:
- The key is in format
- The key is in format
- {server|client}:device-type[:random-key] [-timeout=timeout]
- {server|client}:device-type[:random-key] [-timeout=timeout]
"""
"""
# pylint: disable=invalid-name
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
os
import
os
...
@@ -30,11 +32,11 @@ from ..contrib import util
...
@@ -30,11 +32,11 @@ from ..contrib import util
from
.
import
base
from
.
import
base
from
.
base
import
TrackerCode
from
.
base
import
TrackerCode
def
_server_env
(
load_library
,
logger
):
logger
=
logging
.
getLogger
(
'RPCServer'
)
def
_server_env
(
load_library
):
"""Server environment function return temp dir"""
"""Server environment function return temp dir"""
temp
=
util
.
tempdir
()
temp
=
util
.
tempdir
()
if
logger
is
None
:
logger
=
logging
.
getLogger
()
# pylint: disable=unused-variable
# pylint: disable=unused-variable
@register_func
(
"tvm.rpc.server.workpath"
)
@register_func
(
"tvm.rpc.server.workpath"
)
...
@@ -59,13 +61,10 @@ def _server_env(load_library, logger):
...
@@ -59,13 +61,10 @@ def _server_env(load_library, logger):
return
temp
return
temp
def
_serve_loop
(
sock
,
addr
,
load_library
,
silent
):
def
_serve_loop
(
sock
,
addr
,
load_library
):
"""Server loop"""
"""Server loop"""
logger
=
logging
.
getLogger
(
"RPCServer"
)
if
silent
:
logger
.
disabled
=
True
sockfd
=
sock
.
fileno
()
sockfd
=
sock
.
fileno
()
temp
=
_server_env
(
load_library
,
logger
)
temp
=
_server_env
(
load_library
)
base
.
_ServerLoop
(
sockfd
)
base
.
_ServerLoop
(
sockfd
)
temp
.
remove
()
temp
.
remove
()
logger
.
info
(
"Finish serving
%
s"
,
addr
)
logger
.
info
(
"Finish serving
%
s"
,
addr
)
...
@@ -79,12 +78,8 @@ def _parse_server_opt(opts):
...
@@ -79,12 +78,8 @@ def _parse_server_opt(opts):
ret
[
"timeout"
]
=
float
(
kv
[
9
:])
ret
[
"timeout"
]
=
float
(
kv
[
9
:])
return
ret
return
ret
def
_listen_loop
(
sock
,
port
,
rpc_key
,
tracker_addr
,
load_library
,
custom_addr
,
silent
):
def
_listen_loop
(
sock
,
port
,
rpc_key
,
tracker_addr
,
load_library
,
custom_addr
):
"""Listening loop of the server master."""
"""Listening loop of the server master."""
logger
=
logging
.
getLogger
(
"RPCServer"
)
if
silent
:
logger
.
disabled
=
True
def
_accept_conn
(
listen_sock
,
tracker_conn
,
ping_period
=
2
):
def
_accept_conn
(
listen_sock
,
tracker_conn
,
ping_period
=
2
):
"""Accept connection from the other places.
"""Accept connection from the other places.
...
@@ -148,7 +143,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
...
@@ -148,7 +143,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
if
arr
[
0
]
!=
expect_header
:
if
arr
[
0
]
!=
expect_header
:
conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
close
()
conn
.
close
()
logger
.
info
(
"mismatch key from
%
s"
,
addr
)
logger
.
warning
(
"mismatch key from
%
s"
,
addr
)
continue
continue
else
:
else
:
conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_CODE_SUCCESS
))
conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_CODE_SUCCESS
))
...
@@ -162,7 +157,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
...
@@ -162,7 +157,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
try
:
try
:
# step 1: setup tracker and report to tracker
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
,
silent
=
silent
)
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_TRACKER_MAGIC
))
tracker_conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"<i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
magic
=
struct
.
unpack
(
"<i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
...
@@ -182,15 +177,12 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
...
@@ -182,15 +177,12 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
tracker_conn
=
None
tracker_conn
=
None
continue
continue
except
RuntimeError
as
exc
:
except
RuntimeError
as
exc
:
if
silent
:
return
else
:
raise
exc
raise
exc
# step 3: serving
# step 3: serving
logger
.
info
(
"connection from
%
s"
,
addr
)
logger
.
info
(
"connection from
%
s"
,
addr
)
server_proc
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
server_proc
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
conn
,
addr
,
load_library
,
silent
))
args
=
(
conn
,
addr
,
load_library
))
server_proc
.
deamon
=
True
server_proc
.
deamon
=
True
server_proc
.
start
()
server_proc
.
start
()
# close from our side.
# close from our side.
...
@@ -202,10 +194,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
...
@@ -202,10 +194,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
server_proc
.
terminate
()
server_proc
.
terminate
()
def
_connect_proxy_loop
(
addr
,
key
,
load_library
,
silent
):
def
_connect_proxy_loop
(
addr
,
key
,
load_library
):
logger
=
logging
.
getLogger
(
"RPCProxy"
)
if
silent
:
logger
.
disabled
=
True
key
=
"server:"
+
key
key
=
"server:"
+
key
retry_count
=
0
retry_count
=
0
max_retry
=
5
max_retry
=
5
...
@@ -221,7 +210,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
...
@@ -221,7 +210,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logger
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
logger
.
warning
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
keylen
=
struct
.
unpack
(
"<i"
,
base
.
recvall
(
sock
,
4
))[
0
]
keylen
=
struct
.
unpack
(
"<i"
,
base
.
recvall
(
sock
,
4
))[
0
]
...
@@ -229,7 +218,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
...
@@ -229,7 +218,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logger
.
info
(
"connected to
%
s"
,
str
(
addr
))
logger
.
info
(
"connected to
%
s"
,
str
(
addr
))
process
=
multiprocessing
.
Process
(
process
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
sock
,
addr
,
load_library
,
silent
))
target
=
_serve_loop
,
args
=
(
sock
,
addr
,
load_library
))
process
.
deamon
=
True
process
.
deamon
=
True
process
.
start
()
process
.
start
()
sock
.
close
()
sock
.
close
()
...
@@ -240,7 +229,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
...
@@ -240,7 +229,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
retry_count
=
0
retry_count
=
0
except
(
socket
.
error
,
IOError
)
as
err
:
except
(
socket
.
error
,
IOError
)
as
err
:
retry_count
+=
1
retry_count
+=
1
logger
.
info
(
"Error encountered
%
s, retry in
%
g sec"
,
str
(
err
),
retry_period
)
logger
.
warning
(
"Error encountered
%
s, retry in
%
g sec"
,
str
(
err
),
retry_period
)
if
retry_count
>
max_retry
:
if
retry_count
>
max_retry
:
raise
RuntimeError
(
"Maximum retry error: last error:
%
s"
%
str
(
err
))
raise
RuntimeError
(
"Maximum retry error: last error:
%
s"
%
str
(
err
))
time
.
sleep
(
retry_period
)
time
.
sleep
(
retry_period
)
...
@@ -323,9 +312,8 @@ class Server(object):
...
@@ -323,9 +312,8 @@ class Server(object):
self
.
custom_addr
=
custom_addr
self
.
custom_addr
=
custom_addr
self
.
use_popen
=
use_popen
self
.
use_popen
=
use_popen
self
.
logger
=
logging
.
getLogger
(
"RPCServer"
)
if
silent
:
if
silent
:
self
.
logger
.
disabled
=
True
logger
.
setLevel
(
logging
.
WARN
)
if
use_popen
:
if
use_popen
:
cmd
=
[
sys
.
executable
,
cmd
=
[
sys
.
executable
,
...
@@ -360,18 +348,18 @@ class Server(object):
...
@@ -360,18 +348,18 @@ class Server(object):
raise
sock_err
raise
sock_err
if
not
self
.
port
:
if
not
self
.
port
:
raise
ValueError
(
"cannot bind to any port in [
%
d,
%
d)"
%
(
port
,
port_end
))
raise
ValueError
(
"cannot bind to any port in [
%
d,
%
d)"
%
(
port
,
port_end
))
self
.
logger
.
info
(
"bind to
%
s:
%
d"
,
host
,
self
.
port
)
logger
.
info
(
"bind to
%
s:
%
d"
,
host
,
self
.
port
)
sock
.
listen
(
1
)
sock
.
listen
(
1
)
self
.
sock
=
sock
self
.
sock
=
sock
self
.
proc
=
multiprocessing
.
Process
(
self
.
proc
=
multiprocessing
.
Process
(
target
=
_listen_loop
,
args
=
(
target
=
_listen_loop
,
args
=
(
self
.
sock
,
self
.
port
,
key
,
tracker_addr
,
load_library
,
self
.
sock
,
self
.
port
,
key
,
tracker_addr
,
load_library
,
self
.
custom_addr
,
silent
))
self
.
custom_addr
))
self
.
proc
.
deamon
=
True
self
.
proc
.
deamon
=
True
self
.
proc
.
start
()
self
.
proc
.
start
()
else
:
else
:
self
.
proc
=
multiprocessing
.
Process
(
self
.
proc
=
multiprocessing
.
Process
(
target
=
_connect_proxy_loop
,
args
=
((
host
,
port
),
key
,
load_library
,
silent
))
target
=
_connect_proxy_loop
,
args
=
((
host
,
port
),
key
,
load_library
))
self
.
proc
.
deamon
=
True
self
.
proc
.
deamon
=
True
self
.
proc
.
start
()
self
.
proc
.
start
()
...
...
python/tvm/rpc/tracker.py
View file @
136061dc
...
@@ -23,6 +23,8 @@ List of available APIs:
...
@@ -23,6 +23,8 @@ List of available APIs:
- input: [TrackerCode.REQUEST, [key, user, priority]]
- input: [TrackerCode.REQUEST, [key, user, priority]]
- return: [TrackerCode.SUCCESS, [url, port, match-key]]
- return: [TrackerCode.SUCCESS, [url, port, match-key]]
"""
"""
# pylint: disable=invalid-name
import
heapq
import
heapq
import
time
import
time
import
logging
import
logging
...
@@ -37,12 +39,13 @@ try:
...
@@ -37,12 +39,13 @@ try:
from
.
import
tornado_util
from
.
import
tornado_util
except
ImportError
as
error_msg
:
except
ImportError
as
error_msg
:
raise
ImportError
(
raise
ImportError
(
"RPCTracker module requires tornado package
%
s"
%
error_msg
)
"RPCTracker module requires tornado package
%
s
. Try 'pip install tornado'.
"
%
error_msg
)
from
.._ffi.base
import
py_str
from
.._ffi.base
import
py_str
from
.
import
base
from
.
import
base
from
.base
import
RPC_TRACKER_MAGIC
,
TrackerCode
from
.base
import
RPC_TRACKER_MAGIC
,
TrackerCode
logger
=
logging
.
getLogger
(
"RPCTracker"
)
class
Scheduler
(
object
):
class
Scheduler
(
object
):
"""Abstratc interface of scheduler."""
"""Abstratc interface of scheduler."""
...
@@ -141,11 +144,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
...
@@ -141,11 +144,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
def
_init_conn
(
self
,
message
):
def
_init_conn
(
self
,
message
):
"""Initialie the connection"""
"""Initialie the connection"""
if
len
(
message
)
!=
4
:
if
len
(
message
)
!=
4
:
logg
ing
.
info
(
"Invalid connection from
%
s"
,
self
.
name
())
logg
er
.
warning
(
"Invalid connection from
%
s"
,
self
.
name
())
self
.
close
()
self
.
close
()
magic
=
struct
.
unpack
(
'<i'
,
message
)[
0
]
magic
=
struct
.
unpack
(
'<i'
,
message
)[
0
]
if
magic
!=
RPC_TRACKER_MAGIC
:
if
magic
!=
RPC_TRACKER_MAGIC
:
logg
ing
.
info
(
"Invalid magic from
%
s"
,
self
.
name
())
logg
er
.
warning
(
"Invalid magic from
%
s"
,
self
.
name
())
self
.
close
()
self
.
close
()
self
.
write_message
(
struct
.
pack
(
'<i'
,
RPC_TRACKER_MAGIC
),
binary
=
True
)
self
.
write_message
(
struct
.
pack
(
'<i'
,
RPC_TRACKER_MAGIC
),
binary
=
True
)
self
.
_init_req_nbytes
=
0
self
.
_init_req_nbytes
=
0
...
@@ -232,14 +235,14 @@ class TCPEventHandler(tornado_util.TCPHandler):
...
@@ -232,14 +235,14 @@ class TCPEventHandler(tornado_util.TCPHandler):
status
=
self
.
_tracker
.
summary
()
status
=
self
.
_tracker
.
summary
()
self
.
ret_value
([
TrackerCode
.
SUCCESS
,
status
])
self
.
ret_value
([
TrackerCode
.
SUCCESS
,
status
])
else
:
else
:
logg
ing
.
info
(
"Unknown tracker code
%
d"
,
code
)
logg
er
.
warning
(
"Unknown tracker code
%
d"
,
code
)
self
.
close
()
self
.
close
()
def
on_close
(
self
):
def
on_close
(
self
):
self
.
_tracker
.
_connections
.
remove
(
self
)
self
.
_tracker
.
_connections
.
remove
(
self
)
def
on_error
(
self
,
err
):
def
on_error
(
self
,
err
):
logg
ing
.
info
(
"
%
s: Error in RPC Tracker:
%
s"
,
self
.
name
(),
err
)
logg
er
.
warning
(
"
%
s: Error in RPC Tracker:
%
s"
,
self
.
name
(),
err
)
self
.
close
()
self
.
close
()
...
@@ -335,9 +338,8 @@ class Tracker(object):
...
@@ -335,9 +338,8 @@ class Tracker(object):
port
=
9190
,
port
=
9190
,
port_end
=
9199
,
port_end
=
9199
,
silent
=
False
):
silent
=
False
):
self
.
logger
=
logging
.
getLogger
(
"RPCTracker"
)
if
silent
:
if
silent
:
self
.
logger
.
disabled
=
True
logger
.
setLevel
(
logging
.
WARN
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
port
=
None
self
.
port
=
None
...
@@ -354,7 +356,7 @@ class Tracker(object):
...
@@ -354,7 +356,7 @@ class Tracker(object):
raise
sock_err
raise
sock_err
if
not
self
.
port
:
if
not
self
.
port
:
raise
ValueError
(
"cannot bind to any port in [
%
d,
%
d)"
%
(
port
,
port_end
))
raise
ValueError
(
"cannot bind to any port in [
%
d,
%
d)"
%
(
port
,
port_end
))
self
.
logger
.
info
(
"bind to
%
s:
%
d"
,
host
,
self
.
port
)
logger
.
info
(
"bind to
%
s:
%
d"
,
host
,
self
.
port
)
sock
.
listen
(
1
)
sock
.
listen
(
1
)
self
.
proc
=
multiprocessing
.
Process
(
self
.
proc
=
multiprocessing
.
Process
(
target
=
_tracker_server
,
args
=
(
sock
,
self
.
stop_key
))
target
=
_tracker_server
,
args
=
(
sock
,
self
.
stop_key
))
...
@@ -380,7 +382,7 @@ class Tracker(object):
...
@@ -380,7 +382,7 @@ class Tracker(object):
self
.
_stop_tracker
()
self
.
_stop_tracker
()
self
.
proc
.
join
(
1
)
self
.
proc
.
join
(
1
)
if
self
.
proc
.
is_alive
():
if
self
.
proc
.
is_alive
():
self
.
logger
.
info
(
"Terminating Tracker Server..."
)
logger
.
info
(
"Terminating Tracker Server..."
)
self
.
proc
.
terminate
()
self
.
proc
.
terminate
()
self
.
proc
=
None
self
.
proc
=
None
...
...
tutorials/autotvm/tune_conv2d_cuda.py
View file @
136061dc
...
@@ -154,7 +154,8 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
...
@@ -154,7 +154,8 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
# for this template
# for this template
# logging config (for printing tuning log to screen)
# logging config (for printing tuning log to screen)
logging
.
basicConfig
(
level
=
logging
.
DEBUG
,
stream
=
sys
.
stdout
)
logging
.
getLogger
(
'autotvm'
)
.
setLevel
(
logging
.
DEBUG
)
logging
.
getLogger
(
'autotvm'
)
.
addHandler
(
logging
.
StreamHandler
(
sys
.
stdout
))
# the last layer in resnet
# the last layer in resnet
N
,
H
,
W
,
CO
,
CI
,
KH
,
KW
,
strides
,
padding
=
1
,
7
,
7
,
512
,
512
,
3
,
3
,
(
1
,
1
),
(
1
,
1
)
N
,
H
,
W
,
CO
,
CI
,
KH
,
KW
,
strides
,
padding
=
1
,
7
,
7
,
512
,
512
,
3
,
3
,
(
1
,
1
),
(
1
,
1
)
...
...
tutorials/autotvm/tune_nnvm_arm.py
View file @
136061dc
...
@@ -163,8 +163,10 @@ def get_network(name, batch_size):
...
@@ -163,8 +163,10 @@ def get_network(name, batch_size):
# Set Tuning Options
# Set Tuning Options
# ------------------
# ------------------
# Before tuning, we should do some configurations. Here I use an RK3399 board
# Before tuning, we should do some configurations. Here I use an RK3399 board
# in our environment as example. In your setting, you should modify the target
# as example. In your setting, you should modify the target and device_key accordingly.
# and device_key accordingly.
# set :code:`use_android` to True if you use android phone.
#### DEVICE CONFIG ####
# Replace "aarch64-linux-gnu" with the correct target of your board.
# Replace "aarch64-linux-gnu" with the correct target of your board.
# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
...
@@ -173,7 +175,10 @@ target = tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu')
...
@@ -173,7 +175,10 @@ target = tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu')
# Also replace this with the device key in your tracker
# Also replace this with the device key in your tracker
device_key
=
'rk3399'
device_key
=
'rk3399'
# tuning option
# Set this to True if you use android phone
use_android
=
False
#### TUNING OPTION ####
network
=
'resnet-18'
network
=
'resnet-18'
log_file
=
"
%
s.
%
s.log"
%
(
device_key
,
network
)
log_file
=
"
%
s.
%
s.log"
%
(
device_key
,
network
)
dtype
=
'float32'
dtype
=
'float32'
...
@@ -181,17 +186,17 @@ dtype = 'float32'
...
@@ -181,17 +186,17 @@ dtype = 'float32'
tuning_option
=
{
tuning_option
=
{
'log_filename'
:
log_file
,
'log_filename'
:
log_file
,
'tuner'
:
'xgb'
,
'tuner'
:
'xgb'
,
'n_trial'
:
1000
,
'n_trial'
:
1000
,
'early_stopping'
:
2
0
0
,
'early_stopping'
:
2
5
0
,
'measure_option'
:
autotvm
.
measure_option
(
'measure_option'
:
autotvm
.
measure_option
(
autotvm
.
use_rpc
(
device_key
,
host
=
'localhost'
,
port
=
9190
),
autotvm
.
use_rpc
(
device_key
,
host
=
'localhost'
,
port
=
9190
),
number
=
4
,
number
=
4
,
parallel_num
=
1
,
parallel_num
=
1
,
timeout
=
10
)
,
timeout
=
10
,
build_func
=
'ndk'
if
use_android
else
'default'
,
'use_transfer_learning'
:
True
,
)
,
}
}
####################################################################
####################################################################
...
@@ -208,9 +213,6 @@ tuning_option = {
...
@@ -208,9 +213,6 @@ tuning_option = {
# If your device is very slow or a single conv2d operator in your network has large FLOPs,
# If your device is very slow or a single conv2d operator in your network has large FLOPs,
# consider setting timeout larger.
# consider setting timeout larger.
#
#
# **For android phone**, add :code:`build_func='ndk'` to the argument list of
# :code:`autotvm.measure_option` to use Android NDK for creating shared library.
#
###################################################################
###################################################################
# Begin Tuning
# Begin Tuning
...
@@ -280,12 +282,14 @@ def tune_tasks(tasks,
...
@@ -280,12 +282,14 @@ def tune_tasks(tasks,
def
tune_and_evaluate
():
def
tune_and_evaluate
():
# extract workloads from nnvm graph
# extract workloads from nnvm graph
print
(
"Extract tasks..."
)
net
,
params
,
shape
,
out_shape
=
get_network
(
network
,
batch_size
=
1
)
net
,
params
,
shape
,
out_shape
=
get_network
(
network
,
batch_size
=
1
)
tasks
=
autotvm
.
task
.
extract_from_graph
(
net
,
shape
=
shape
,
dtype
=
dtype
,
tasks
=
autotvm
.
task
.
extract_from_graph
(
net
,
shape
=
shape
,
dtype
=
dtype
,
symbols
=
(
nnvm
.
sym
.
conv2d
,),
symbols
=
(
nnvm
.
sym
.
conv2d
,),
target
=
target
)
target
=
target
)
# run tuning tasks
# run tuning tasks
print
(
"Tuning..."
)
tune_tasks
(
tasks
,
**
tuning_option
)
tune_tasks
(
tasks
,
**
tuning_option
)
# compile kernels with history best records
# compile kernels with history best records
...
@@ -329,6 +333,7 @@ def tune_and_evaluate():
...
@@ -329,6 +333,7 @@ def tune_and_evaluate():
# We do not run the tuning in our webpage server since it takes too long.
# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run by yourself.
# Uncomment the following line to run by yourself.
# tune_and_evaluate()
# tune_and_evaluate()
######################################################################
######################################################################
...
@@ -341,6 +346,8 @@ def tune_and_evaluate():
...
@@ -341,6 +346,8 @@ def tune_and_evaluate():
#
#
# .. code-block:: bash
# .. code-block:: bash
#
#
# Extract tasks...
# Tuning...
# [Task 1/16] Current/Best: 13.15/ 20.49 GFLOPS | Progress: (297/1000) | 348.51 s Done.
# [Task 1/16] Current/Best: 13.15/ 20.49 GFLOPS | Progress: (297/1000) | 348.51 s Done.
# [Task 2/16] Current/Best: 16.66/ 22.64 GFLOPS | Progress: (475/1000) | 415.42 s Done.
# [Task 2/16] Current/Best: 16.66/ 22.64 GFLOPS | Progress: (475/1000) | 415.42 s Done.
# [Task 3/16] Current/Best: 10.33/ 14.19 GFLOPS | Progress: (306/1000) | 239.61 s Done.
# [Task 3/16] Current/Best: 10.33/ 14.19 GFLOPS | Progress: (306/1000) | 239.61 s Done.
...
@@ -362,3 +369,23 @@ def tune_and_evaluate():
...
@@ -362,3 +369,23 @@ def tune_and_evaluate():
# Evaluate inference time cost...
# Evaluate inference time cost...
# Mean inference time (std dev): 156.51 ms (0.89 ms)
# Mean inference time (std dev): 156.51 ms (0.89 ms)
#
#
######################################################################
#
# .. note:: **Meet some problems?**
#
# The auto tuning module is error prone. If you always see " 0.00/ 0.00 GFLOPS",
# then there must be something wrong.
#
# First, make sure you set the correct configuration of your device.
# Then, you can print debug information by adding these lines in the beginning
# of the script. It will print every measurement result, where you can find useful
# error messages.
#
# .. code-block:: python
#
# import logging
# logging.getLogger('autotvm').setLevel(logging.DEBUG)
#
# Finally, always feel free to ask our community for help on https://discuss.tvm.ai
tutorials/autotvm/tune_simple_template.py
View file @
136061dc
...
@@ -267,8 +267,9 @@ print(task.config_space)
...
@@ -267,8 +267,9 @@ print(task.config_space)
# We will log the tuning results into a log file. This file can be
# We will log the tuning results into a log file. This file can be
# used to get the best config later.
# used to get the best config later.
# logging config (for printing tuning log to screen)
# logging config (for printing tuning log to the screen)
logging
.
basicConfig
(
level
=
logging
.
DEBUG
,
stream
=
sys
.
stdout
)
logging
.
getLogger
(
'autotvm'
)
.
setLevel
(
logging
.
DEBUG
)
logging
.
getLogger
(
'autotvm'
)
.
addHandler
(
logging
.
StreamHandler
(
sys
.
stdout
))
# use local cpu, measure 5 times for every config to reduce variance
# use local cpu, measure 5 times for every config to reduce variance
measure_option
=
autotvm
.
measure_option
(
'local'
,
measure_option
=
autotvm
.
measure_option
(
'local'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment