Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
136061dc
Commit
136061dc
authored
Aug 03, 2018
by
Lianmin Zheng
Committed by
Tianqi Chen
Aug 03, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[AUTOTVM] Improve tutorial and logging (#1544)
parent
33606741
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
181 additions
and
97 deletions
+181
-97
python/tvm/autotvm/measure/__init__.py
+1
-1
python/tvm/autotvm/measure/measure_methods.py
+44
-4
python/tvm/autotvm/record.py
+7
-6
python/tvm/autotvm/task/dispatcher.py
+5
-1
python/tvm/autotvm/tophub.py
+4
-2
python/tvm/autotvm/tuner/callback.py
+6
-3
python/tvm/autotvm/tuner/sa_model_optimizer.py
+6
-4
python/tvm/autotvm/tuner/tuner.py
+16
-6
python/tvm/autotvm/tuner/xgboost_cost_model.py
+7
-5
python/tvm/autotvm/util.py
+4
-3
python/tvm/rpc/base.py
+5
-6
python/tvm/rpc/proxy.py
+3
-2
python/tvm/rpc/server.py
+19
-31
python/tvm/rpc/tracker.py
+11
-9
tutorials/autotvm/tune_conv2d_cuda.py
+2
-1
tutorials/autotvm/tune_nnvm_arm.py
+38
-11
tutorials/autotvm/tune_simple_template.py
+3
-2
No files found.
python/tvm/autotvm/measure/__init__.py
View file @
136061dc
"""Distributed executor infrastructure to scale up the tuning"""
from
.measure
import
MeasureInput
,
MeasureResult
,
MeasureErrorNo
,
measure_option
from
.measure_methods
import
request_remote
,
create_measure_batch
,
use_rpc
from
.measure_methods
import
request_remote
,
c
heck_remote
,
c
reate_measure_batch
,
use_rpc
from
.local_executor
import
LocalExecutor
from
.executor
import
Future
,
Executor
python/tvm/autotvm/measure/measure_methods.py
View file @
136061dc
...
...
@@ -9,6 +9,7 @@ import logging
import
os
import
time
from
random
import
getrandbits
import
threading
import
numpy
as
np
...
...
@@ -23,6 +24,7 @@ from ..task.space import InstantiationError
from
.measure
import
MeasureResult
,
MeasureErrorNo
from
.local_executor
import
LocalExecutor
logger
=
logging
.
getLogger
(
'autotvm'
)
class
HashMismatchError
(
ValueError
):
"""Raised when the code hash of a submitted config doesn't match that on the
...
...
@@ -42,9 +44,9 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
If is none, will use environment variable "TVM_TRACKER_HOST"
and "TVM_TRACKER_PORT"
priority: int, optional
priority of this request, larger is more prior
The
priority of this request, larger is more prior
timeout: float, optional
timeout of this session (units: seconds)
The
timeout of this session (units: seconds)
Returns
------
...
...
@@ -63,6 +65,33 @@ def request_remote(device_key, tracker_addr=None, priority=1, timeout=60):
session_timeout
=
timeout
)
return
remote
def
check_remote
(
target
,
device_key
,
tracker_addr
=
None
,
priority
=
2
,
timeout
=
10
):
"""
Check the availability of a remote device
Parameters
----------
target: Target
The wanted compilation target
device_key: string
device key of registered device in tracker
tracker_addr: Tuple(string, int), optional
The address of rpc tracker in (host, port) format.
If is none, will use environment variable "TVM_TRACKER_HOST"
and "TVM_TRACKER_PORT"
priority: int, optional
The priority of this request, larger is more prior
timeout: float, optional
The timeout of this check (units: seconds).
If time is out, a RuntimerError will be raised.
"""
def
_check
():
remote
=
request_remote
(
device_key
,
tracker_addr
,
priority
)
remote
.
context
(
str
(
target
))
t
=
threading
.
Thread
(
target
=
_check
,)
t
.
start
()
t
.
join
(
timeout
)
return
not
t
.
is_alive
()
def
create_measure_batch
(
task
,
option
):
"""Get a standard measure_batch function.
...
...
@@ -115,6 +144,17 @@ def create_measure_batch(task, option):
build_func
=
default_build_func
build_kwargs
[
'use_ndk'
]
=
True
# check the availability of remote devices
if
hasattr
(
measure_func
,
'rpc_info'
):
rpc_info
=
measure_func
.
rpc_info
if
check_remote
(
task
.
target
,
rpc_info
[
'key'
],
(
rpc_info
[
'host'
],
rpc_info
[
'port'
])):
logger
.
info
(
"Get devices for measurement successfully!"
)
else
:
raise
RuntimeError
(
"Cannot get remote devices from the tracker. "
"Please check the status of tracker by "
"'python -m tvm.exec.query_rpc_tracker --port [THE PORT YOU USE]' "
"and make sure you have free devices on the queue status."
)
# add device info of cuda and opencl target
if
(
'cuda'
in
task
.
target
.
keys
or
'opencl'
in
task
.
target
.
keys
)
\
and
hasattr
(
measure_func
,
'rpc_info'
):
...
...
@@ -313,7 +353,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
continue
except
InstantiationError
as
e
:
tstamp
=
time
.
time
()
res_pack
.
append
(
MeasureResult
((
e
,),
res_pack
.
append
(
MeasureResult
((
InstantiationError
(
str
(
e
))
,),
MeasureErrorNo
.
INSTANTIATION_ERROR
,
tstamp
-
tic
,
tstamp
))
continue
...
...
@@ -346,7 +386,7 @@ def _measure_common(input_pack, build_func, build_kwargs, number, repeat,
if
ref_output
:
for
expected
,
real
in
zip
(
ref_output
,
args
):
if
not
np
.
allclose
(
expected
,
real
.
asnumpy
(),
rtol
=
1e-4
):
logg
ing
.
warning
(
"Wrong Answer!"
)
logg
er
.
warning
(
"Wrong Answer!"
)
errno
=
MeasureErrorNo
.
WRONG_ANSWER
except
TVMError
as
exc
:
msg
=
str
(
exc
)
...
...
python/tvm/autotvm/record.py
View file @
136061dc
...
...
@@ -18,6 +18,7 @@ from .task import ConfigEntity, ApplyHistoryBest
from
.measure
import
MeasureInput
,
MeasureResult
AUTOTVM_LOG_VERSION
=
0.1
logger
=
logging
.
getLogger
(
'autotvm'
)
try
:
# convert unicode to str for python2
_unicode
=
unicode
...
...
@@ -181,10 +182,10 @@ def split_workload(in_file, clean=True):
tic
=
time
.
time
()
lines
=
list
(
open
(
in_file
)
.
readlines
())
logg
ing
.
info
(
"start converting..."
)
logg
er
.
info
(
"start converting..."
)
pool
=
multiprocessing
.
Pool
()
lines
=
pool
.
map
(
decode
,
lines
)
logg
ing
.
info
(
"map done
%.2
f"
,
time
.
time
()
-
tic
)
logg
er
.
info
(
"map done
%.2
f"
,
time
.
time
()
-
tic
)
wkl_dict
=
OrderedDict
()
for
inp
,
res
in
lines
:
...
...
@@ -206,13 +207,13 @@ def split_workload(in_file, clean=True):
cleaned
.
append
([
inp
,
res
])
# write to file
logg
ing
.
info
(
"Key:
%
s
\t
Valid:
%
d
\t
Dup:
%
d
\t
"
,
k
,
len
(
cleaned
),
len
(
v
)
-
len
(
cleaned
))
logg
er
.
info
(
"Key:
%
s
\t
Valid:
%
d
\t
Dup:
%
d
\t
"
,
k
,
len
(
cleaned
),
len
(
v
)
-
len
(
cleaned
))
with
open
(
args
.
i
+
".
%03
d.wkl"
%
i
,
'w'
)
as
fout
:
for
inp
,
res
in
cleaned
:
fout
.
write
(
encode
(
inp
,
res
)
+
'
\n
'
)
else
:
for
i
,
(
k
,
v
)
in
enumerate
(
wkl_dict
.
items
()):
logg
ing
.
info
(
"Key:
%
s
\t
Num:
%
d"
,
k
,
len
(
v
))
logg
er
.
info
(
"Key:
%
s
\t
Num:
%
d"
,
k
,
len
(
v
))
with
open
(
args
.
i
+
".
%03
d.wkl"
%
i
,
'w'
)
as
fout
:
for
inp
,
res
in
v
:
fout
.
write
(
encode
(
inp
,
res
)
+
'
\n
'
)
...
...
@@ -238,7 +239,7 @@ def pick_best(in_file, out_file):
for
v
in
best_context
.
best_by_targetkey
.
values
():
best_set
.
add
(
measure_str_key
(
v
[
0
]))
logg
ing
.
info
(
"Extract
%
d best records from the
%
s"
,
len
(
best_set
),
in_file
)
logg
er
.
info
(
"Extract
%
d best records from the
%
s"
,
len
(
best_set
),
in_file
)
fout
=
open
(
out_file
,
'w'
)
if
isinstance
(
out_file
,
str
)
else
out_file
for
inp
,
res
in
load_from_file
(
in_file
):
...
...
@@ -270,7 +271,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
"--code"
,
action
=
'store_true'
)
args
=
parser
.
parse_args
()
logg
ing
.
basicConfig
(
level
=
logging
.
INFO
)
logg
er
.
basicConfig
(
level
=
logger
.
INFO
)
if
args
.
mode
==
'pick'
:
args
.
o
=
args
.
o
or
args
.
i
+
".best.log"
...
...
python/tvm/autotvm/task/dispatcher.py
View file @
136061dc
...
...
@@ -10,6 +10,8 @@ of the DispatchContext base class.
- During search, we can use it to pass the current proposal from tuner.
- During evaluation, we can use it to set pick the best policy.
"""
# pylint: disable=invalid-name
from
__future__
import
absolute_import
as
_abs
import
logging
...
...
@@ -19,6 +21,8 @@ import numpy as np
from
tvm
import
target
as
_target
logger
=
logging
.
getLogger
(
'autotvm'
)
class
DispatchContext
(
object
):
"""
Base class of dispatch context.
...
...
@@ -216,7 +220,7 @@ class ApplyHistoryBest(DispatchContext):
best_by_model
[
key
]
=
(
inp
,
res
)
break
logg
ing
.
debug
(
"Finish loading
%
d records"
,
counter
)
logg
er
.
debug
(
"Finish loading
%
d records"
,
counter
)
def
query
(
self
,
target
,
workload
):
if
target
is
None
:
...
...
python/tvm/autotvm/tophub.py
View file @
136061dc
...
...
@@ -4,6 +4,7 @@ To get the best performance, we typically need auto-tuning for the specific devi
TVM releases pre-tuned parameters in TopHub for some common networks and hardware targets.
TVM will download these parameters for you when you create the target for the first time.
"""
# pylint: disable=invalid-name
import
logging
import
os
...
...
@@ -16,6 +17,7 @@ from ..contrib.download import download
AUTOTVM_TOPHUB_ROOT_PATH
=
os
.
path
.
join
(
os
.
path
.
expanduser
(
'~'
),
".tvm"
,
"tophub"
)
logger
=
logging
.
getLogger
(
'autotvm'
)
def
_alias
(
name
):
"""convert alias for some packages"""
...
...
@@ -79,7 +81,7 @@ def download_package(backend):
os
.
mkdir
(
path
)
backend
=
_alias
(
backend
)
logg
ing
.
info
(
"Download pre-tuned parameters for
%
s"
,
backend
)
logg
er
.
info
(
"Download pre-tuned parameters for
%
s"
,
backend
)
download
(
"https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/
%
s.log"
%
backend
,
os
.
path
.
join
(
rootpath
,
backend
+
".log"
),
True
,
verbose
=
0
)
...
...
@@ -110,7 +112,7 @@ def list_packages():
"""
path
=
tempdir
()
filename
=
path
.
relpath
(
"info.json"
)
logg
ing
.
info
(
"Download meta info for pre-tuned parameters"
)
logg
er
.
info
(
"Download meta info for pre-tuned parameters"
)
download
(
"https://raw.githubusercontent.com/uwsaml/tvm-distro/master/tophub/info.json"
,
filename
,
True
,
verbose
=
0
)
...
...
python/tvm/autotvm/tuner/callback.py
View file @
136061dc
...
...
@@ -2,11 +2,13 @@
"""Namespace of callback utilities of AutoTVM"""
import
sys
import
time
import
logging
import
numpy
as
np
from
..
import
record
logger
=
logging
.
getLogger
(
'autotvm'
)
def
log_to_file
(
file_out
,
protocol
=
'json'
):
"""Log the tuning records into file.
...
...
@@ -90,7 +92,7 @@ def progress_bar(total, prefix=''):
prefix: str
The prefix of output message
"""
class
_Context
:
class
_Context
(
object
)
:
"""Context to store local variables"""
def
__init__
(
self
):
self
.
best_flops
=
0
...
...
@@ -112,11 +114,12 @@ def progress_bar(total, prefix=''):
if
res
.
error_no
==
0
:
flops
=
inp
.
task
.
flop
/
np
.
mean
(
res
.
costs
)
if
logger
.
level
<
logging
.
DEBUG
:
# only print progress bar in non-debug mode
ctx
.
cur_flops
=
flops
ctx
.
best_flops
=
tuner
.
best_flops
sys
.
stdout
.
write
(
'
\r
%
s Current/Best:
%7.2
f/
%7.2
f GFLOPS | Progress: (
%
d/
%
d) '
'|
%.2
f s
'
%
sys
.
stdout
.
write
(
'
%
s Current/Best:
%7.2
f/
%7.2
f GFLOPS | Progress: (
%
d/
%
d) '
'|
%.2
f s
\r
'
%
(
prefix
,
ctx
.
cur_flops
/
1e9
,
ctx
.
best_flops
/
1e9
,
ctx
.
ct
,
ctx
.
total
,
time
.
time
()
-
tic
))
sys
.
stdout
.
flush
()
...
...
python/tvm/autotvm/tuner/sa_model_optimizer.py
View file @
136061dc
# pylint: disable=consider-using-enumerate
# pylint: disable=consider-using-enumerate
, invalid-name
"""
Cost model optimizer based on simulated annealing
"""
...
...
@@ -12,6 +12,8 @@ import numpy as np
from
..util
import
sample_ints
from
.model_based_tuner
import
ModelOptimizer
,
knob2point
,
point2knob
logger
=
logging
.
getLogger
(
'autotvm'
)
class
SimulatedAnnealingOptimizer
(
ModelOptimizer
):
"""parallel simulated annealing optimization algorithm
...
...
@@ -103,16 +105,16 @@ class SimulatedAnnealingOptimizer(ModelOptimizer):
if
log_interval
and
k
%
log_interval
==
0
:
t_str
=
"
%.2
f"
%
t
logg
ing
.
debug
(
"SA iter:
%
d
\t
last_update:
%
d
\t
max-0:
%.2
f
\t
max-1:
%.2
f
\t
temp:
%
s
\t
"
logg
er
.
debug
(
"SA iter:
%
d
\t
last_update:
%
d
\t
max-0:
%.2
f
\t
max-1:
%.2
f
\t
temp:
%
s
\t
"
"elapsed:
%.2
f"
,
k
,
k_last_modify
,
heap_items
[
0
][
0
],
np
.
max
([
v
for
v
,
_
in
heap_items
]),
t_str
,
time
.
time
()
-
tic
)
heap_items
.
sort
(
key
=
lambda
item
:
-
item
[
0
])
logg
ing
.
debug
(
"SA iter:
%
d
\t
last_update:
%
d
\t
max-0:
%.2
f
\t
max-1:
%.2
f
\t
elapsed:
%.2
f"
,
logg
er
.
debug
(
"SA iter:
%
d
\t
last_update:
%
d
\t
max-0:
%.2
f
\t
max-1:
%.2
f
\t
elapsed:
%.2
f"
,
k
,
k_last_modify
,
heap_items
[
-
1
][
0
],
heap_items
[
0
][
0
],
time
.
time
()
-
tic
)
logg
ing
.
debug
(
"SA Maximums:
%
s"
,
heap_items
)
logg
er
.
debug
(
"SA Maximums:
%
s"
,
heap_items
)
if
self
.
persistent
:
self
.
points
=
points
...
...
python/tvm/autotvm/tuner/tuner.py
View file @
136061dc
...
...
@@ -4,11 +4,12 @@ import logging
import
numpy
as
np
from
..measure
import
MeasureInput
from
..measure
import
create_measure_batch
from
..measure
import
MeasureInput
,
create_measure_batch
from
..env
import
GLOBAL_SCOPE
logger
=
logging
.
getLogger
(
'autotvm'
)
class
Tuner
(
object
):
"""Base class for tuners
...
...
@@ -86,9 +87,10 @@ class Tuner(object):
measure_batch
=
create_measure_batch
(
self
.
task
,
measure_option
)
parallel_num
=
getattr
(
measure_batch
,
'parallel_num'
,
1
)
early_stopping
=
early_stopping
or
1e9
old_level
=
logger
.
level
GLOBAL_SCOPE
.
in_tuning
=
True
i
=
0
i
=
error_ct
=
0
while
i
<
n_trial
:
if
not
self
.
has_next
():
break
...
...
@@ -103,15 +105,18 @@ class Tuner(object):
config
=
inp
.
config
if
res
.
error_no
==
0
:
flops
=
inp
.
task
.
flop
/
np
.
mean
(
res
.
costs
)
error_ct
=
0
else
:
flops
=
0
error_ct
+=
1
if
flops
>
self
.
best_flops
:
self
.
best_flops
=
flops
self
.
best_config
=
config
self
.
best_measure_pair
=
(
inp
,
res
)
self
.
best_iter
=
i
+
k
logg
ing
.
debug
(
"No:
%
d
\t
GFLOPS:
%.2
f/
%.2
f
\t
result:
%
s
\t
%
s"
,
logg
er
.
debug
(
"No:
%
d
\t
GFLOPS:
%.2
f/
%.2
f
\t
result:
%
s
\t
%
s"
,
i
+
k
+
1
,
flops
/
1e9
,
self
.
best_flops
/
1e9
,
res
,
config
)
...
...
@@ -123,11 +128,16 @@ class Tuner(object):
callback
(
self
,
inputs
,
results
)
if
i
>
self
.
best_iter
+
early_stopping
:
logg
ing
.
debug
(
"Early stopped. Best iter:
%
d."
,
self
.
best_iter
)
logg
er
.
debug
(
"Early stopped. Best iter:
%
d."
,
self
.
best_iter
)
break
GLOBAL_SCOPE
.
in_tuning
=
False
if
error_ct
>
50
:
logger
.
warning
(
"Too many errors happen in the tuning. Now is in debug mode"
)
logger
.
setLevel
(
logging
.
DEBUG
)
else
:
logger
.
setLevel
(
old_level
)
GLOBAL_SCOPE
.
in_tuning
=
False
del
measure_batch
def
reset
(
self
):
...
...
python/tvm/autotvm/tuner/xgboost_cost_model.py
View file @
136061dc
...
...
@@ -16,6 +16,8 @@ from ..util import get_rank
from
.metric
import
max_curve
,
recall_curve
,
cover_curve
from
.model_based_tuner
import
CostModel
,
FeatureCache
logger
=
logging
.
getLogger
(
'autotvm'
)
class
XGBoostCostModel
(
CostModel
):
"""XGBoost as cost model
...
...
@@ -163,7 +165,7 @@ class XGBoostCostModel(CostModel):
],
verbose_eval
=
self
.
log_interval
)])
logg
ing
.
debug
(
"XGB train:
%.2
f
\t
obs:
%
d
\t
error:
%
d
\t
n_cache:
%
d"
,
logg
er
.
debug
(
"XGB train:
%.2
f
\t
obs:
%
d
\t
error:
%
d
\t
n_cache:
%
d"
,
time
.
time
()
-
tic
,
len
(
xs
),
len
(
xs
)
-
np
.
sum
(
valid_index
),
self
.
feature_cache
.
size
(
self
.
fea_type
))
...
...
@@ -173,7 +175,7 @@ class XGBoostCostModel(CostModel):
self
.
_reset_pool
()
args
=
list
(
records
)
logg
ing
.
debug
(
"XGB load
%
d entries from history log file"
,
len
(
args
))
logg
er
.
debug
(
"XGB load
%
d entries from history log file"
,
len
(
args
))
if
self
.
fea_type
==
'itervar'
:
feature_extract_func
=
_extract_itervar_feature_log
...
...
@@ -208,7 +210,7 @@ class XGBoostCostModel(CostModel):
],
verbose_eval
=
self
.
log_interval
)])
logg
ing
.
debug
(
"XGB train:
%.2
f
\t
obs:
%
d"
,
time
.
time
()
-
tic
,
len
(
xs
))
logg
er
.
debug
(
"XGB train:
%.2
f
\t
obs:
%
d"
,
time
.
time
()
-
tic
,
len
(
xs
))
def
predict
(
self
,
xs
,
output_margin
=
False
):
feas
=
self
.
_get_feature
(
xs
)
...
...
@@ -403,7 +405,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
infos
.
append
(
"
%
s:
%.6
f"
%
(
item
[
0
],
item
[
1
]))
if
not
isinstance
(
verbose_eval
,
bool
)
and
verbose_eval
and
i
%
verbose_eval
==
0
:
logg
ing
.
debug
(
"
\t
"
.
join
(
infos
))
logg
er
.
debug
(
"
\t
"
.
join
(
infos
))
if
log_file
:
with
open
(
log_file
,
"a"
)
as
fout
:
fout
.
write
(
"
\t
"
.
join
(
infos
)
+
'
\n
'
)
...
...
@@ -435,7 +437,7 @@ def custom_callback(stopping_rounds, metric, fevals, evals=(), log_file=None,
elif
env
.
iteration
-
best_iteration
>=
stopping_rounds
:
best_msg
=
state
[
'best_msg'
]
if
verbose_eval
and
env
.
rank
==
0
:
logg
ing
.
debug
(
"XGB stopped. Best iteration:
%
s "
,
best_msg
)
logg
er
.
debug
(
"XGB stopped. Best iteration:
%
s "
,
best_msg
)
raise
EarlyStopException
(
best_iteration
)
return
callback
...
...
python/tvm/autotvm/util.py
View file @
136061dc
...
...
@@ -8,6 +8,7 @@ import numpy as np
from
..
import
expr
,
ir_pass
logger
=
logging
.
getLogger
(
'autotvm'
)
class
EmptyContext
(
object
):
"""An empty context"""
...
...
@@ -92,15 +93,15 @@ def pool_map(func, args, batch_size, verbose=False, pool=None):
tic
=
time
.
time
()
local_pool
=
pool
or
multiprocessing
.
Pool
()
if
verbose
:
logg
ing
.
info
(
"mapping begin"
)
logg
er
.
info
(
"mapping begin"
)
for
i
in
range
(
0
,
len
(
args
),
batch_size
):
if
verbose
:
logg
ing
.
info
(
"mapping
%
d/
%
d elapsed
%.2
f"
,
i
,
len
(
args
),
logg
er
.
info
(
"mapping
%
d/
%
d elapsed
%.2
f"
,
i
,
len
(
args
),
time
.
time
()
-
tic
)
tmp
=
np
.
array
(
local_pool
.
map
(
func
,
args
[
i
:
i
+
batch_size
]))
ret
=
tmp
if
ret
is
None
else
np
.
concatenate
((
ret
,
tmp
))
if
verbose
:
logg
ing
.
info
(
"mapping done"
)
logg
er
.
info
(
"mapping done"
)
if
not
pool
:
local_pool
.
close
()
return
ret
...
...
python/tvm/rpc/base.py
View file @
136061dc
"""Base definitions for RPC."""
# pylint: disable=invalid-name
from
__future__
import
absolute_import
import
socket
...
...
@@ -23,6 +25,7 @@ RPC_CODE_DUPLICATE = RPC_MAGIC + 1
# cannot found matched key in server
RPC_CODE_MISMATCH
=
RPC_MAGIC
+
2
logger
=
logging
.
getLogger
(
'RPCServer'
)
class
TrackerCode
(
object
):
"""Enumeration code for the RPC tracker"""
...
...
@@ -120,7 +123,7 @@ def random_key(prefix, cmap=None):
return
prefix
+
str
(
random
.
random
())
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
,
silent
=
False
):
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
):
"""Connect to a TPC address with retry
This function is only reliable to short period of server restart.
...
...
@@ -135,9 +138,6 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
retry_period : float
Number of seconds before we retry again.
silent: bool
whether run in silent mode
"""
tstart
=
time
.
time
()
while
True
:
...
...
@@ -152,8 +152,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5, silent=False):
if
period
>
timeout
:
raise
RuntimeError
(
"Failed to connect to server
%
s"
%
str
(
addr
))
if
not
silent
:
logging
.
info
(
"Cannot connect to tracker
%
s, retry in
%
g secs..."
,
logger
.
warning
(
"Cannot connect to tracker
%
s, retry in
%
g secs..."
,
str
(
addr
),
retry_period
)
time
.
sleep
(
retry_period
)
...
...
python/tvm/rpc/proxy.py
View file @
136061dc
...
...
@@ -23,7 +23,8 @@ try:
from
tornado
import
ioloop
from
.
import
tornado_util
except
ImportError
as
error_msg
:
raise
ImportError
(
"RPCProxy module requires tornado package
%
s"
%
error_msg
)
raise
ImportError
(
"RPCProxy module requires tornado package
%
s. Try 'pip install tornado'."
%
error_msg
)
from
.
import
base
from
.base
import
TrackerCode
...
...
@@ -540,7 +541,7 @@ def websocket_proxy_server(url, key=""):
def
_connect
(
key
):
conn
=
yield
websocket
.
websocket_connect
(
url
)
on_message
=
create_on_message
(
conn
)
temp
=
_server_env
(
None
,
None
)
temp
=
_server_env
(
None
)
# Start connecton
conn
.
write_message
(
struct
.
pack
(
'<i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
key
=
"server:"
+
key
...
...
python/tvm/rpc/server.py
View file @
136061dc
...
...
@@ -8,6 +8,8 @@ Server is TCP based with the following protocol:
- The key is in format
- {server|client}:device-type[:random-key] [-timeout=timeout]
"""
# pylint: disable=invalid-name
from
__future__
import
absolute_import
import
os
...
...
@@ -30,11 +32,11 @@ from ..contrib import util
from
.
import
base
from
.
base
import
TrackerCode
def
_server_env
(
load_library
,
logger
):
logger
=
logging
.
getLogger
(
'RPCServer'
)
def
_server_env
(
load_library
):
"""Server environment function return temp dir"""
temp
=
util
.
tempdir
()
if
logger
is
None
:
logger
=
logging
.
getLogger
()
# pylint: disable=unused-variable
@register_func
(
"tvm.rpc.server.workpath"
)
...
...
@@ -59,13 +61,10 @@ def _server_env(load_library, logger):
return
temp
def
_serve_loop
(
sock
,
addr
,
load_library
,
silent
):
def
_serve_loop
(
sock
,
addr
,
load_library
):
"""Server loop"""
logger
=
logging
.
getLogger
(
"RPCServer"
)
if
silent
:
logger
.
disabled
=
True
sockfd
=
sock
.
fileno
()
temp
=
_server_env
(
load_library
,
logger
)
temp
=
_server_env
(
load_library
)
base
.
_ServerLoop
(
sockfd
)
temp
.
remove
()
logger
.
info
(
"Finish serving
%
s"
,
addr
)
...
...
@@ -79,12 +78,8 @@ def _parse_server_opt(opts):
ret
[
"timeout"
]
=
float
(
kv
[
9
:])
return
ret
def
_listen_loop
(
sock
,
port
,
rpc_key
,
tracker_addr
,
load_library
,
custom_addr
,
silent
):
def
_listen_loop
(
sock
,
port
,
rpc_key
,
tracker_addr
,
load_library
,
custom_addr
):
"""Listening loop of the server master."""
logger
=
logging
.
getLogger
(
"RPCServer"
)
if
silent
:
logger
.
disabled
=
True
def
_accept_conn
(
listen_sock
,
tracker_conn
,
ping_period
=
2
):
"""Accept connection from the other places.
...
...
@@ -148,7 +143,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
if
arr
[
0
]
!=
expect_header
:
conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
close
()
logger
.
info
(
"mismatch key from
%
s"
,
addr
)
logger
.
warning
(
"mismatch key from
%
s"
,
addr
)
continue
else
:
conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_CODE_SUCCESS
))
...
...
@@ -162,7 +157,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
try
:
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
,
silent
=
silent
)
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"<i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
...
...
@@ -182,15 +177,12 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
tracker_conn
=
None
continue
except
RuntimeError
as
exc
:
if
silent
:
return
else
:
raise
exc
# step 3: serving
logger
.
info
(
"connection from
%
s"
,
addr
)
server_proc
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
conn
,
addr
,
load_library
,
silent
))
args
=
(
conn
,
addr
,
load_library
))
server_proc
.
deamon
=
True
server_proc
.
start
()
# close from our side.
...
...
@@ -202,10 +194,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr, s
server_proc
.
terminate
()
def
_connect_proxy_loop
(
addr
,
key
,
load_library
,
silent
):
logger
=
logging
.
getLogger
(
"RPCProxy"
)
if
silent
:
logger
.
disabled
=
True
def
_connect_proxy_loop
(
addr
,
key
,
load_library
):
key
=
"server:"
+
key
retry_count
=
0
max_retry
=
5
...
...
@@ -221,7 +210,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logger
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
logger
.
warning
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
keylen
=
struct
.
unpack
(
"<i"
,
base
.
recvall
(
sock
,
4
))[
0
]
...
...
@@ -229,7 +218,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logger
.
info
(
"connected to
%
s"
,
str
(
addr
))
process
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
sock
,
addr
,
load_library
,
silent
))
target
=
_serve_loop
,
args
=
(
sock
,
addr
,
load_library
))
process
.
deamon
=
True
process
.
start
()
sock
.
close
()
...
...
@@ -240,7 +229,7 @@ def _connect_proxy_loop(addr, key, load_library, silent):
retry_count
=
0
except
(
socket
.
error
,
IOError
)
as
err
:
retry_count
+=
1
logger
.
info
(
"Error encountered
%
s, retry in
%
g sec"
,
str
(
err
),
retry_period
)
logger
.
warning
(
"Error encountered
%
s, retry in
%
g sec"
,
str
(
err
),
retry_period
)
if
retry_count
>
max_retry
:
raise
RuntimeError
(
"Maximum retry error: last error:
%
s"
%
str
(
err
))
time
.
sleep
(
retry_period
)
...
...
@@ -323,9 +312,8 @@ class Server(object):
self
.
custom_addr
=
custom_addr
self
.
use_popen
=
use_popen
self
.
logger
=
logging
.
getLogger
(
"RPCServer"
)
if
silent
:
self
.
logger
.
disabled
=
True
logger
.
setLevel
(
logging
.
WARN
)
if
use_popen
:
cmd
=
[
sys
.
executable
,
...
...
@@ -360,18 +348,18 @@ class Server(object):
raise
sock_err
if
not
self
.
port
:
raise
ValueError
(
"cannot bind to any port in [
%
d,
%
d)"
%
(
port
,
port_end
))
self
.
logger
.
info
(
"bind to
%
s:
%
d"
,
host
,
self
.
port
)
logger
.
info
(
"bind to
%
s:
%
d"
,
host
,
self
.
port
)
sock
.
listen
(
1
)
self
.
sock
=
sock
self
.
proc
=
multiprocessing
.
Process
(
target
=
_listen_loop
,
args
=
(
self
.
sock
,
self
.
port
,
key
,
tracker_addr
,
load_library
,
self
.
custom_addr
,
silent
))
self
.
custom_addr
))
self
.
proc
.
deamon
=
True
self
.
proc
.
start
()
else
:
self
.
proc
=
multiprocessing
.
Process
(
target
=
_connect_proxy_loop
,
args
=
((
host
,
port
),
key
,
load_library
,
silent
))
target
=
_connect_proxy_loop
,
args
=
((
host
,
port
),
key
,
load_library
))
self
.
proc
.
deamon
=
True
self
.
proc
.
start
()
...
...
python/tvm/rpc/tracker.py
View file @
136061dc
...
...
@@ -23,6 +23,8 @@ List of available APIs:
- input: [TrackerCode.REQUEST, [key, user, priority]]
- return: [TrackerCode.SUCCESS, [url, port, match-key]]
"""
# pylint: disable=invalid-name
import
heapq
import
time
import
logging
...
...
@@ -37,12 +39,13 @@ try:
from
.
import
tornado_util
except
ImportError
as
error_msg
:
raise
ImportError
(
"RPCTracker module requires tornado package
%
s"
%
error_msg
)
"RPCTracker module requires tornado package
%
s
. Try 'pip install tornado'.
"
%
error_msg
)
from
.._ffi.base
import
py_str
from
.
import
base
from
.base
import
RPC_TRACKER_MAGIC
,
TrackerCode
logger
=
logging
.
getLogger
(
"RPCTracker"
)
class
Scheduler
(
object
):
"""Abstratc interface of scheduler."""
...
...
@@ -141,11 +144,11 @@ class TCPEventHandler(tornado_util.TCPHandler):
def
_init_conn
(
self
,
message
):
"""Initialie the connection"""
if
len
(
message
)
!=
4
:
logg
ing
.
info
(
"Invalid connection from
%
s"
,
self
.
name
())
logg
er
.
warning
(
"Invalid connection from
%
s"
,
self
.
name
())
self
.
close
()
magic
=
struct
.
unpack
(
'<i'
,
message
)[
0
]
if
magic
!=
RPC_TRACKER_MAGIC
:
logg
ing
.
info
(
"Invalid magic from
%
s"
,
self
.
name
())
logg
er
.
warning
(
"Invalid magic from
%
s"
,
self
.
name
())
self
.
close
()
self
.
write_message
(
struct
.
pack
(
'<i'
,
RPC_TRACKER_MAGIC
),
binary
=
True
)
self
.
_init_req_nbytes
=
0
...
...
@@ -232,14 +235,14 @@ class TCPEventHandler(tornado_util.TCPHandler):
status
=
self
.
_tracker
.
summary
()
self
.
ret_value
([
TrackerCode
.
SUCCESS
,
status
])
else
:
logg
ing
.
info
(
"Unknown tracker code
%
d"
,
code
)
logg
er
.
warning
(
"Unknown tracker code
%
d"
,
code
)
self
.
close
()
def
on_close
(
self
):
self
.
_tracker
.
_connections
.
remove
(
self
)
def
on_error
(
self
,
err
):
logg
ing
.
info
(
"
%
s: Error in RPC Tracker:
%
s"
,
self
.
name
(),
err
)
logg
er
.
warning
(
"
%
s: Error in RPC Tracker:
%
s"
,
self
.
name
(),
err
)
self
.
close
()
...
...
@@ -335,9 +338,8 @@ class Tracker(object):
port
=
9190
,
port_end
=
9199
,
silent
=
False
):
self
.
logger
=
logging
.
getLogger
(
"RPCTracker"
)
if
silent
:
self
.
logger
.
disabled
=
True
logger
.
setLevel
(
logging
.
WARN
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
port
=
None
...
...
@@ -354,7 +356,7 @@ class Tracker(object):
raise
sock_err
if
not
self
.
port
:
raise
ValueError
(
"cannot bind to any port in [
%
d,
%
d)"
%
(
port
,
port_end
))
self
.
logger
.
info
(
"bind to
%
s:
%
d"
,
host
,
self
.
port
)
logger
.
info
(
"bind to
%
s:
%
d"
,
host
,
self
.
port
)
sock
.
listen
(
1
)
self
.
proc
=
multiprocessing
.
Process
(
target
=
_tracker_server
,
args
=
(
sock
,
self
.
stop_key
))
...
...
@@ -380,7 +382,7 @@ class Tracker(object):
self
.
_stop_tracker
()
self
.
proc
.
join
(
1
)
if
self
.
proc
.
is_alive
():
self
.
logger
.
info
(
"Terminating Tracker Server..."
)
logger
.
info
(
"Terminating Tracker Server..."
)
self
.
proc
.
terminate
()
self
.
proc
=
None
...
...
tutorials/autotvm/tune_conv2d_cuda.py
View file @
136061dc
...
...
@@ -154,7 +154,8 @@ def conv2d_no_batching(N, H, W, CI, CO, KH, KW, stride, padding):
# for this template
# logging config (for printing tuning log to screen)
logging
.
basicConfig
(
level
=
logging
.
DEBUG
,
stream
=
sys
.
stdout
)
logging
.
getLogger
(
'autotvm'
)
.
setLevel
(
logging
.
DEBUG
)
logging
.
getLogger
(
'autotvm'
)
.
addHandler
(
logging
.
StreamHandler
(
sys
.
stdout
))
# the last layer in resnet
N
,
H
,
W
,
CO
,
CI
,
KH
,
KW
,
strides
,
padding
=
1
,
7
,
7
,
512
,
512
,
3
,
3
,
(
1
,
1
),
(
1
,
1
)
...
...
tutorials/autotvm/tune_nnvm_arm.py
View file @
136061dc
...
...
@@ -163,8 +163,10 @@ def get_network(name, batch_size):
# Set Tuning Options
# ------------------
# Before tuning, we should do some configurations. Here I use an RK3399 board
# in our environment as example. In your setting, you should modify the target
# and device_key accordingly.
# as example. In your setting, you should modify the target and device_key accordingly.
# set :code:`use_android` to True if you use android phone.
#### DEVICE CONFIG ####
# Replace "aarch64-linux-gnu" with the correct target of your board.
# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
...
...
@@ -173,7 +175,10 @@ target = tvm.target.create('llvm -device=arm_cpu -target=aarch64-linux-gnu')
# Also replace this with the device key in your tracker
device_key
=
'rk3399'
# tuning option
# Set this to True if you use android phone
use_android
=
False
#### TUNING OPTION ####
network
=
'resnet-18'
log_file
=
"
%
s.
%
s.log"
%
(
device_key
,
network
)
dtype
=
'float32'
...
...
@@ -181,17 +186,17 @@ dtype = 'float32'
tuning_option
=
{
'log_filename'
:
log_file
,
'tuner'
:
'xgb'
,
'tuner'
:
'xgb'
,
'n_trial'
:
1000
,
'early_stopping'
:
2
0
0
,
'early_stopping'
:
2
5
0
,
'measure_option'
:
autotvm
.
measure_option
(
autotvm
.
use_rpc
(
device_key
,
host
=
'localhost'
,
port
=
9190
),
number
=
4
,
parallel_num
=
1
,
timeout
=
10
)
,
'use_transfer_learning'
:
True
,
timeout
=
10
,
build_func
=
'ndk'
if
use_android
else
'default'
,
)
,
}
####################################################################
...
...
@@ -208,9 +213,6 @@ tuning_option = {
# If your device is very slow or a single conv2d operator in your network has large FLOPs,
# consider setting timeout larger.
#
# **For android phone**, add :code:`build_func='ndk'` to the argument list of
# :code:`autotvm.measure_option` to use Android NDK for creating shared library.
#
###################################################################
# Begin Tuning
...
...
@@ -280,12 +282,14 @@ def tune_tasks(tasks,
def
tune_and_evaluate
():
# extract workloads from nnvm graph
print
(
"Extract tasks..."
)
net
,
params
,
shape
,
out_shape
=
get_network
(
network
,
batch_size
=
1
)
tasks
=
autotvm
.
task
.
extract_from_graph
(
net
,
shape
=
shape
,
dtype
=
dtype
,
symbols
=
(
nnvm
.
sym
.
conv2d
,),
target
=
target
)
# run tuning tasks
print
(
"Tuning..."
)
tune_tasks
(
tasks
,
**
tuning_option
)
# compile kernels with history best records
...
...
@@ -329,6 +333,7 @@ def tune_and_evaluate():
# We do not run the tuning in our webpage server since it takes too long.
# Uncomment the following line to run by yourself.
# tune_and_evaluate()
######################################################################
...
...
@@ -341,6 +346,8 @@ def tune_and_evaluate():
#
# .. code-block:: bash
#
# Extract tasks...
# Tuning...
# [Task 1/16] Current/Best: 13.15/ 20.49 GFLOPS | Progress: (297/1000) | 348.51 s Done.
# [Task 2/16] Current/Best: 16.66/ 22.64 GFLOPS | Progress: (475/1000) | 415.42 s Done.
# [Task 3/16] Current/Best: 10.33/ 14.19 GFLOPS | Progress: (306/1000) | 239.61 s Done.
...
...
@@ -362,3 +369,23 @@ def tune_and_evaluate():
# Evaluate inference time cost...
# Mean inference time (std dev): 156.51 ms (0.89 ms)
#
######################################################################
#
# .. note:: **Meet some problems?**
#
# The auto tuning module is error prone. If you always see " 0.00/ 0.00 GFLOPS",
# then there must be something wrong.
#
# First, make sure you set the correct configuration of your device.
# Then, you can print debug information by adding these lines in the beginning
# of the script. It will print every measurement result, where you can find useful
# error messages.
#
# .. code-block:: python
#
# import logging
# logging.getLogger('autotvm').setLevel(logging.DEBUG)
#
# Finally, always feel free to ask our community for help on https://discuss.tvm.ai
tutorials/autotvm/tune_simple_template.py
View file @
136061dc
...
...
@@ -267,8 +267,9 @@ print(task.config_space)
# We will log the tuning results into a log file. This file can be
# used to get the best config later.
# logging config (for printing tuning log to screen)
logging
.
basicConfig
(
level
=
logging
.
DEBUG
,
stream
=
sys
.
stdout
)
# logging config (for printing tuning log to the screen)
logging
.
getLogger
(
'autotvm'
)
.
setLevel
(
logging
.
DEBUG
)
logging
.
getLogger
(
'autotvm'
)
.
addHandler
(
logging
.
StreamHandler
(
sys
.
stdout
))
# use local cpu, measure 5 times for every config to reduce variance
measure_option
=
autotvm
.
measure_option
(
'local'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment