Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
79fc6672
Commit
79fc6672
authored
Apr 05, 2018
by
Tianqi Chen
Committed by
GitHub
Apr 05, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RPC] Tracker status query (#1081)
parent
6bd8dbc7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
183 additions
and
12 deletions
+183
-12
python/tvm/contrib/rpc/base.py
+2
-1
python/tvm/contrib/rpc/client.py
+94
-6
python/tvm/contrib/rpc/server.py
+5
-0
python/tvm/contrib/rpc/tracker.py
+36
-3
python/tvm/exec/query_rpc_tracker.py
+32
-0
python/tvm/exec/rpc_tracker.py
+2
-2
tests/python/contrib/test_rpc_tracker.py
+12
-0
No files found.
python/tvm/contrib/rpc/base.py
View file @
79fc6672
...
@@ -32,7 +32,8 @@ class TrackerCode(object):
...
@@ -32,7 +32,8 @@ class TrackerCode(object):
STOP
=
2
STOP
=
2
PUT
=
3
PUT
=
3
REQUEST
=
4
REQUEST
=
4
UPDATE_INFO
=
5
SUMMARY
=
6
RPC_SESS_MASK
=
128
RPC_SESS_MASK
=
128
...
...
python/tvm/contrib/rpc/client.py
View file @
79fc6672
...
@@ -4,6 +4,7 @@ from __future__ import absolute_import
...
@@ -4,6 +4,7 @@ from __future__ import absolute_import
import
os
import
os
import
socket
import
socket
import
struct
import
struct
import
time
from
.
import
base
from
.
import
base
from
..._ffi.base
import
TVMError
from
..._ffi.base
import
TVMError
...
@@ -150,7 +151,6 @@ class TrackerSession(object):
...
@@ -150,7 +151,6 @@ class TrackerSession(object):
def
__init__
(
self
,
addr
):
def
__init__
(
self
,
addr
):
self
.
_addr
=
addr
self
.
_addr
=
addr
self
.
_sock
=
None
self
.
_sock
=
None
self
.
_max_request_retry
=
5
self
.
_connect
()
self
.
_connect
()
def
__del__
(
self
):
def
__del__
(
self
):
...
@@ -169,7 +169,38 @@ class TrackerSession(object):
...
@@ -169,7 +169,38 @@ class TrackerSession(object):
self
.
_sock
.
close
()
self
.
_sock
.
close
()
self
.
_sock
=
None
self
.
_sock
=
None
def
request
(
self
,
key
,
priority
=
1
,
session_timeout
=
0
):
def
summary
(
self
):
"""Get the summary dict of the tracker."""
base
.
sendjson
(
self
.
_sock
,
[
base
.
TrackerCode
.
SUMMARY
])
value
=
base
.
recvjson
(
self
.
_sock
)
if
value
[
0
]
!=
base
.
TrackerCode
.
SUCCESS
:
raise
RuntimeError
(
"Invalid return value
%
s"
%
str
(
value
))
return
value
[
1
]
def
text_summary
(
self
):
"""Get a text summary of the tracker."""
data
=
self
.
summary
()
res
=
""
res
+=
"Server List
\n
"
res
+=
"----------------------------
\n
"
res
+=
"server-address
\t
key
\n
"
res
+=
"----------------------------
\n
"
for
item
in
data
[
"server_info"
]:
addr
=
item
[
"addr"
]
res
+=
addr
[
0
]
+
":"
+
str
(
addr
[
1
])
+
"
\t
"
res
+=
item
[
"key"
]
+
"
\n
"
res
+=
"----------------------------
\n
"
res
+=
"
\n
"
res
+=
"Queue Status
\n
"
res
+=
"----------------------------
\n
"
res
+=
"key
\t
free
\t
pending
\n
"
res
+=
"----------------------------
\n
"
for
k
,
v
in
data
[
"queue_info"
]
.
items
():
res
+=
"
%
s
\t
%
d
\t
%
g
\n
"
%
(
k
,
v
[
"free"
],
v
[
"pending"
])
res
+=
"----------------------------
\n
"
return
res
def
request
(
self
,
key
,
priority
=
1
,
session_timeout
=
0
,
max_retry
=
5
):
"""Request a new connection from the tracker.
"""Request a new connection from the tracker.
Parameters
Parameters
...
@@ -184,8 +215,12 @@ class TrackerSession(object):
...
@@ -184,8 +215,12 @@ class TrackerSession(object):
The duration of the session, allows server to kill
The duration of the session, allows server to kill
the connection when duration is longer than this value.
the connection when duration is longer than this value.
When duration is zero, it means the request must always be kept alive.
When duration is zero, it means the request must always be kept alive.
max_retry : int, optional
Maximum number of times to retry before giving up.
"""
"""
for
_
in
range
(
self
.
_max_request_retry
):
last_err
=
None
for
_
in
range
(
max_retry
):
try
:
try
:
if
self
.
_sock
is
None
:
if
self
.
_sock
is
None
:
self
.
_connect
()
self
.
_connect
()
...
@@ -196,10 +231,63 @@ class TrackerSession(object):
...
@@ -196,10 +231,63 @@ class TrackerSession(object):
raise
RuntimeError
(
"Invalid return value
%
s"
%
str
(
value
))
raise
RuntimeError
(
"Invalid return value
%
s"
%
str
(
value
))
url
,
port
,
matchkey
=
value
[
1
]
url
,
port
,
matchkey
=
value
[
1
]
return
connect
(
url
,
port
,
key
+
matchkey
,
session_timeout
)
return
connect
(
url
,
port
,
key
+
matchkey
,
session_timeout
)
except
socket
.
error
:
except
socket
.
error
as
err
:
self
.
close
()
self
.
close
()
except
TVMError
:
last_err
=
err
pass
except
TVMError
as
err
:
last_err
=
err
raise
RuntimeError
(
"Cannot request
%
s after
%
d retry, last_error:
%
s"
%
(
key
,
max_retry
,
str
(
last_err
)))
def
request_and_run
(
self
,
key
,
func
,
priority
=
1
,
session_timeout
=
0
,
max_retry
=
2
):
"""Request a resource from tracker and run the func.
This function safeguards against rare server node dropout during execution.
In such a case, a new resource will be requested and func will be run again.
Parameters
----------
key : str
The type key of the device.
func : function of session -> value
A stateless function
priority : int, optional
The priority of the request.
session_timeout : float, optional
The duration of the session, allows server to kill
the connection when duration is longer than this value.
When duration is zero, it means the request must always be kept alive.
max_retry : int, optional
Maximum number of times to retry the function before giving up.
"""
last_err
=
None
for
_
in
range
(
max_retry
):
try
:
sess
=
self
.
request
(
key
,
priority
=
priority
,
session_timeout
=
session_timeout
)
tstart
=
time
.
time
()
return
func
(
sess
)
except
TVMError
as
err
:
duration
=
time
.
time
()
-
tstart
# roughly estimate if the error is due to timeout termination
if
session_timeout
and
duration
>=
session_timeout
*
0.95
:
raise
RuntimeError
(
"Session timeout when running
%
s"
%
func
.
__name__
)
last_err
=
err
raise
RuntimeError
(
"Failed to run on
%
s after
%
d retry, last_error:
%
s"
%
(
key
,
max_retry
,
str
(
last_err
)))
def
connect
(
url
,
port
,
key
=
""
,
session_timeout
=
0
):
def
connect
(
url
,
port
,
key
=
""
,
session_timeout
=
0
):
...
...
python/tvm/contrib/rpc/server.py
View file @
79fc6672
...
@@ -137,6 +137,11 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
...
@@ -137,6 +137,11 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
# report status of current queue
cinfo
=
{
"key"
:
"server:"
+
rpc_key
}
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
UPDATE_INFO
,
cinfo
])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
try
:
try
:
# step 2: wait for in-coming connections
# step 2: wait for in-coming connections
conn
,
addr
,
opts
=
_accept_conn
(
sock
,
tracker_conn
)
conn
,
addr
,
opts
=
_accept_conn
(
sock
,
tracker_conn
)
...
...
python/tvm/contrib/rpc/tracker.py
View file @
79fc6672
...
@@ -75,10 +75,15 @@ class Scheduler(object):
...
@@ -75,10 +75,15 @@ class Scheduler(object):
"""
"""
raise
NotImplementedError
()
raise
NotImplementedError
()
def
summary
(
self
):
"""Get summary information of the scheduler."""
raise
NotImplementedError
()
class
PriorityScheduler
(
Scheduler
):
class
PriorityScheduler
(
Scheduler
):
"""Priority based scheduler, FIFO based on time"""
"""Priority based scheduler, FIFO based on time"""
def
__init__
(
self
):
def
__init__
(
self
,
key
):
self
.
_key
=
key
self
.
_values
=
[]
self
.
_values
=
[]
self
.
_requests
=
[]
self
.
_requests
=
[]
...
@@ -98,6 +103,11 @@ class PriorityScheduler(Scheduler):
...
@@ -98,6 +103,11 @@ class PriorityScheduler(Scheduler):
heapq
.
heappush
(
self
.
_requests
,
(
-
priority
,
time
.
time
(),
callback
))
heapq
.
heappush
(
self
.
_requests
,
(
-
priority
,
time
.
time
(),
callback
))
self
.
_schedule
()
self
.
_schedule
()
def
summary
(
self
):
"""Get summary information of the scheduler."""
return
{
"free"
:
len
(
self
.
_values
),
"pending"
:
len
(
self
.
_requests
)}
class
TCPEventHandler
(
tornado_util
.
TCPHandler
):
class
TCPEventHandler
(
tornado_util
.
TCPHandler
):
"""Base asynchronous message handler.
"""Base asynchronous message handler.
...
@@ -113,12 +123,17 @@ class TCPEventHandler(tornado_util.TCPHandler):
...
@@ -113,12 +123,17 @@ class TCPEventHandler(tornado_util.TCPHandler):
self
.
_msg_size
=
0
self
.
_msg_size
=
0
self
.
_addr
=
addr
self
.
_addr
=
addr
self
.
_init_req_nbytes
=
4
self
.
_init_req_nbytes
=
4
self
.
_info
=
{
"addr"
:
addr
}
self
.
_tracker
.
_connections
.
add
(
self
)
self
.
_tracker
.
_connections
.
add
(
self
)
def
name
(
self
):
def
name
(
self
):
"""name of connection"""
"""name of connection"""
return
"TCPSocket:
%
s"
%
str
(
self
.
_addr
)
return
"TCPSocket:
%
s"
%
str
(
self
.
_addr
)
def
summary
(
self
):
"""Summary of this connection"""
return
self
.
_info
def
_init_conn
(
self
,
message
):
def
_init_conn
(
self
,
message
):
"""Initialize the connection"""
"""Initialize the connection"""
if
len
(
message
)
!=
4
:
if
len
(
message
)
!=
4
:
...
@@ -193,6 +208,12 @@ class TCPEventHandler(tornado_util.TCPHandler):
...
@@ -193,6 +208,12 @@ class TCPEventHandler(tornado_util.TCPHandler):
self
.
_tracker
.
stop
()
self
.
_tracker
.
stop
()
else
:
else
:
self
.
ret_value
(
TrackerCode
.
FAIL
)
self
.
ret_value
(
TrackerCode
.
FAIL
)
elif
code
==
TrackerCode
.
UPDATE_INFO
:
self
.
_info
.
update
(
args
[
1
])
self
.
ret_value
(
TrackerCode
.
SUCCESS
)
elif
code
==
TrackerCode
.
SUMMARY
:
status
=
self
.
_tracker
.
summary
()
self
.
ret_value
([
TrackerCode
.
SUCCESS
,
status
])
else
:
else
:
logging
.
info
(
"Unknown tracker code
%
d"
,
code
)
logging
.
info
(
"Unknown tracker code
%
d"
,
code
)
self
.
close
()
self
.
close
()
...
@@ -230,8 +251,7 @@ class TrackerServerHandler(object):
...
@@ -230,8 +251,7 @@ class TrackerServerHandler(object):
def
create_scheduler
(
self
,
key
):
def
create_scheduler
(
self
,
key
):
"""Create a new scheduler."""
"""Create a new scheduler."""
_
=
key
return
PriorityScheduler
(
key
)
return
PriorityScheduler
()
def
put
(
self
,
key
,
value
):
def
put
(
self
,
key
,
value
):
"""Report a new resource to the tracker."""
"""Report a new resource to the tracker."""
...
@@ -252,6 +272,19 @@ class TrackerServerHandler(object):
...
@@ -252,6 +272,19 @@ class TrackerServerHandler(object):
self
.
_sock
.
close
()
self
.
_sock
.
close
()
self
.
_ioloop
.
stop
()
self
.
_ioloop
.
stop
()
def
summary
(
self
):
"""Return a dict summarizing current status."""
qinfo
=
{}
for
k
,
v
in
self
.
_scheduler_map
.
items
():
qinfo
[
k
]
=
v
.
summary
()
cinfo
=
[]
# ignore client connections without key
for
conn
in
self
.
_connections
:
res
=
conn
.
summary
()
if
res
.
get
(
"key"
,
""
)
.
startswith
(
"server"
):
cinfo
.
append
(
res
)
return
{
"queue_info"
:
qinfo
,
"server_info"
:
cinfo
}
def
run
(
self
):
def
run
(
self
):
"""Run the tracker server"""
"""Run the tracker server"""
self
.
_ioloop
.
start
()
self
.
_ioloop
.
start
()
...
...
python/tvm/exec/query_rpc_tracker.py
0 → 100644
View file @
79fc6672
"""Tool to query RPC tracker status"""
from
__future__
import
absolute_import
import
logging
import
argparse
import
os
from
..contrib
import
rpc
def
main
():
"""Main function"""
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--host'
,
type
=
str
,
default
=
""
,
help
=
'the hostname of the tracker'
)
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
None
,
help
=
'The port of the PRC'
)
args
=
parser
.
parse_args
()
logging
.
basicConfig
(
level
=
logging
.
INFO
)
# default to local host or environment variable
if
not
args
.
host
:
args
.
host
=
os
.
environ
.
get
(
"TVM_TRACKER_HOST"
,
"localhost"
)
if
not
args
.
port
:
args
.
port
=
int
(
os
.
environ
.
get
(
"TVM_TRACKER_PORT"
,
"9190"
))
conn
=
rpc
.
connect_tracker
(
args
.
host
,
args
.
port
)
# pylint: disable=superfluous-parens
print
(
"Tracker address
%
s:
%
d
\n
"
%
(
args
.
host
,
args
.
port
))
print
(
"
%
s"
%
conn
.
text_summary
())
if
__name__
==
"__main__"
:
main
()
python/tvm/exec/rpc_tracker.py
View file @
79fc6672
"""
RPC web proxy, allows redirect to websocket based RPC servers(browsers)
"""
"""
Tool to start RPC tracker
"""
from
__future__
import
absolute_import
from
__future__
import
absolute_import
import
logging
import
logging
...
@@ -9,7 +9,7 @@ def main():
...
@@ -9,7 +9,7 @@ def main():
"""Main function"""
"""Main function"""
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--host'
,
type
=
str
,
default
=
"0.0.0.0"
,
parser
.
add_argument
(
'--host'
,
type
=
str
,
default
=
"0.0.0.0"
,
help
=
'the hostname of the
serv
er'
)
help
=
'the hostname of the
track
er'
)
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
9190
,
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
9190
,
help
=
'The port of the PRC'
)
help
=
'The port of the PRC'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
tests/python/contrib/test_rpc_tracker.py
View file @
79fc6672
...
@@ -38,6 +38,15 @@ def check_server_drop():
...
@@ -38,6 +38,15 @@ def check_server_drop():
# Fault tolerance server timeout
# Fault tolerance server timeout
def
check_timeout
(
timeout
,
sleeptime
):
def
check_timeout
(
timeout
,
sleeptime
):
def
myfunc
(
remote
):
time
.
sleep
(
sleeptime
)
f1
=
remote
.
get_function
(
"rpc.test2.addone"
)
assert
f1
(
10
)
==
11
try
:
tclient
.
request_and_run
(
"xyz"
,
myfunc
,
session_timeout
=
timeout
)
except
RuntimeError
:
pass
print
(
tclient
.
text_summary
())
try
:
try
:
remote
=
tclient
.
request
(
"xyz"
,
priority
=
0
,
session_timeout
=
timeout
)
remote
=
tclient
.
request
(
"xyz"
,
priority
=
0
,
session_timeout
=
timeout
)
remote2
=
tclient
.
request
(
"xyz"
,
session_timeout
=
timeout
)
remote2
=
tclient
.
request
(
"xyz"
,
session_timeout
=
timeout
)
...
@@ -48,8 +57,11 @@ def check_server_drop():
...
@@ -48,8 +57,11 @@ def check_server_drop():
assert
f1
(
10
)
==
11
assert
f1
(
10
)
==
11
except
tvm
.
TVMError
as
e
:
except
tvm
.
TVMError
as
e
:
pass
pass
check_timeout
(
0.01
,
0.1
)
check_timeout
(
0.01
,
0.1
)
check_timeout
(
2
,
0
)
check_timeout
(
2
,
0
)
except
ImportError
:
except
ImportError
:
print
(
"Skip because tornado is not available"
)
print
(
"Skip because tornado is not available"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment