Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
14181340
Commit
14181340
authored
Apr 06, 2018
by
Tianqi Chen
Committed by
GitHub
Apr 06, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RPC] More robust tracker protocol (#1085)
* [RPC] More robust tracker protocol * fix normal rpc
parent
2e17e850
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
113 additions
and
28 deletions
+113
-28
python/tvm/contrib/rpc/base.py
+1
-0
python/tvm/contrib/rpc/client.py
+1
-1
python/tvm/contrib/rpc/proxy.py
+51
-16
python/tvm/contrib/rpc/server.py
+25
-7
python/tvm/contrib/rpc/tornado_util.py
+2
-0
python/tvm/contrib/rpc/tracker.py
+16
-3
tests/python/contrib/test_rpc_tracker.py
+17
-1
No files found.
python/tvm/contrib/rpc/base.py
View file @
14181340
...
...
@@ -34,6 +34,7 @@ class TrackerCode(object):
REQUEST
=
4
UPDATE_INFO
=
5
SUMMARY
=
6
GET_PENDING_MATCHKEYS
=
7
RPC_SESS_MASK
=
128
...
...
python/tvm/contrib/rpc/client.py
View file @
14181340
...
...
@@ -230,7 +230,7 @@ class TrackerSession(object):
if
value
[
0
]
!=
base
.
TrackerCode
.
SUCCESS
:
raise
RuntimeError
(
"Invalid return value
%
s"
%
str
(
value
))
url
,
port
,
matchkey
=
value
[
1
]
return
connect
(
url
,
port
,
key
+
matchkey
,
session_timeout
)
return
connect
(
url
,
port
,
matchkey
,
session_timeout
)
except
socket
.
error
as
err
:
self
.
close
()
last_err
=
err
...
...
python/tvm/contrib/rpc/proxy.py
View file @
14181340
...
...
@@ -14,6 +14,7 @@ import socket
import
multiprocessing
import
errno
import
struct
import
time
try
:
import
tornado
...
...
@@ -45,6 +46,7 @@ class ForwardHandler(object):
self
.
rpc_key
=
None
self
.
match_key
=
None
self
.
forward_proxy
=
None
self
.
alloc_time
=
None
def
__del__
(
self
):
logging
.
info
(
"Delete
%
s..."
,
self
.
name
())
...
...
@@ -237,6 +239,7 @@ class ProxyServerHandler(object):
self
.
sock
.
fileno
(),
event_handler
,
self
.
loop
.
READ
)
self
.
_client_pool
=
{}
self
.
_server_pool
=
{}
self
.
timeout_alloc
=
5
self
.
timeout_client
=
timeout_client
self
.
timeout_server
=
timeout_server
# tracker information
...
...
@@ -245,8 +248,12 @@ class ProxyServerHandler(object):
self
.
_tracker_conn
=
None
self
.
_tracker_pending_puts
=
[]
self
.
_key_set
=
set
()
self
.
update_tracker_period
=
2
if
tracker_addr
:
logging
.
info
(
"Tracker address:
%
s"
,
str
(
tracker_addr
))
def
_callback
():
self
.
_update_tracker
(
True
)
self
.
loop
.
call_later
(
self
.
update_tracker_period
,
_callback
)
logging
.
info
(
"RPCProxy: Websock port bind to
%
d"
,
web_port
)
def
_on_event
(
self
,
_
):
...
...
@@ -271,7 +278,22 @@ class ProxyServerHandler(object):
rhs
.
send_data
(
lhs
.
rpc_key
.
encode
(
"utf-8"
))
logging
.
info
(
"Pairup connect
%
s and
%
s"
,
lhs
.
name
(),
rhs
.
name
())
def
_update_tracker
(
self
):
def
_regenerate_server_keys
(
self
,
keys
):
"""Regenerate keys for server pool"""
keyset
=
set
(
self
.
_server_pool
.
keys
())
new_keys
=
[]
# re-generate the server match key, so old information is invalidated.
for
key
in
keys
:
rpc_key
,
_
=
key
.
split
(
":"
)
handle
=
self
.
_server_pool
[
key
]
del
self
.
_server_pool
[
key
]
new_key
=
base
.
random_key
(
rpc_key
+
":"
,
keyset
)
self
.
_server_pool
[
new_key
]
=
handle
keyset
.
add
(
new_key
)
new_keys
.
append
(
new_key
)
return
new_keys
def
_update_tracker
(
self
,
period_update
=
False
):
"""Update information on tracker."""
try
:
if
self
.
_tracker_conn
is
None
:
...
...
@@ -285,13 +307,33 @@ class ProxyServerHandler(object):
# just connect to tracker, need to update all keys
self
.
_tracker_pending_puts
=
self
.
_server_pool
.
keys
()
if
self
.
_tracker_conn
and
period_update
:
# periodically update tracker information
# regenerate key if the key is not in tracker anymore
# and there is no in-coming connection after timeout_alloc
base
.
sendjson
(
self
.
_tracker_conn
,
[
TrackerCode
.
GET_PENDING_MATCHKEYS
])
pending_keys
=
set
(
base
.
recvjson
(
self
.
_tracker_conn
))
update_keys
=
[]
for
k
,
v
in
self
.
_server_pool
.
items
():
if
k
not
in
pending_keys
:
if
v
.
alloc_time
is
None
:
v
.
alloc_time
=
time
.
time
()
elif
time
.
time
()
-
v
.
alloc_time
>
self
.
timeout_alloc
:
update_keys
.
append
(
k
)
v
.
alloc_time
=
None
if
update_keys
:
logging
.
info
(
"RPCProxy: No incoming conn on
%
s, regenerate keys..."
,
str
(
update_keys
))
new_keys
=
self
.
_regenerate_server_keys
(
update_keys
)
self
.
_tracker_pending_puts
+=
new_keys
need_update_info
=
False
# report new connections
for
key
in
self
.
_tracker_pending_puts
:
rpc_key
,
match_key
=
key
.
split
(
":"
)
rpc_key
=
key
.
split
(
":"
)[
0
]
base
.
sendjson
(
self
.
_tracker_conn
,
[
TrackerCode
.
PUT
,
rpc_key
,
(
self
.
_listen_port
,
":"
+
match_
key
)])
(
self
.
_listen_port
,
key
)])
assert
base
.
recvjson
(
self
.
_tracker_conn
)
==
TrackerCode
.
SUCCESS
if
rpc_key
not
in
self
.
_key_set
:
self
.
_key_set
.
add
(
rpc_key
)
...
...
@@ -305,24 +347,17 @@ class ProxyServerHandler(object):
assert
base
.
recvjson
(
self
.
_tracker_conn
)
==
TrackerCode
.
SUCCESS
self
.
_tracker_pending_puts
=
[]
except
(
socket
.
error
,
IOError
)
as
err
:
retry_period
=
5
logging
.
info
(
"Lost tracker connection:
%
s, try reconnect in
%
g sec"
,
str
(
err
),
retry
_period
)
str
(
err
),
self
.
update_tracker
_period
)
self
.
_tracker_conn
.
close
()
self
.
_tracker_conn
=
None
new_pool
=
{}
keyset
=
set
(
self
.
_server_pool
.
keys
())
# re-generate the server match key, so old information is invalidated.
for
key
,
handle
in
self
.
_server_pool
.
items
():
rpc_key
,
_
=
key
.
split
(
":"
)
key
=
base
.
random_key
(
rpc_key
+
":"
,
keyset
)
new_pool
[
key
]
=
handle
keyset
.
add
(
key
)
self
.
_server_pool
=
new_pool
self
.
_regenerate_server_keys
(
self
.
_server_pool
.
keys
())
if
period_update
:
def
_callback
():
self
.
_update_tracker
()
self
.
loop
.
call_later
(
retry
_period
,
_callback
)
self
.
_update_tracker
(
True
)
self
.
loop
.
call_later
(
self
.
update_tracker
_period
,
_callback
)
def
_handler_ready_tracker_mode
(
self
,
handler
):
"""tracker mode to handle handler ready."""
...
...
python/tvm/contrib/rpc/server.py
View file @
14181340
...
...
@@ -6,7 +6,7 @@ Server is TCP based with the following protocol:
- Initial handshake to the peer
- [RPC_MAGIC, keysize(int32), key-bytes]
- The key is in format
- {server|client}:device-type[:
match
key] [-timeout=timeout]
- {server|client}:device-type[:
random-
key] [-timeout=timeout]
"""
from
__future__
import
absolute_import
...
...
@@ -75,7 +75,7 @@ def _parse_server_opt(opts):
def
_listen_loop
(
sock
,
port
,
rpc_key
,
tracker_addr
):
"""Lisenting loop of the server master."""
def
_accept_conn
(
listen_sock
,
tracker_conn
,
ping_period
=
0.1
):
def
_accept_conn
(
listen_sock
,
tracker_conn
,
ping_period
=
2
):
"""Accept connection from the other places.
Parameters
...
...
@@ -89,22 +89,40 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
ping_period : float, optional
ping tracker every k seconds if no connection is accepted.
"""
old_keyset
=
set
()
# Report resource to tracker
if
tracker_conn
:
matchkey
=
base
.
random_key
(
":"
)
matchkey
=
base
.
random_key
(
rpc_key
+
":"
)
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
PUT
,
rpc_key
,
(
port
,
matchkey
)])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
else
:
matchkey
=
""
matchkey
=
rpc_key
unmatch_period_count
=
0
unmatch_timeout
=
4
# Wait until we get a valid connection
while
True
:
if
tracker_conn
:
trigger
=
select
.
select
([
listen_sock
],
[],
[],
ping_period
)
if
not
listen_sock
in
trigger
[
0
]:
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
PING
])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
GET_PENDING_MATCHKEYS
])
pending_keys
=
base
.
recvjson
(
tracker_conn
)
old_keyset
.
add
(
matchkey
)
# if match key not in pending key set
# it means the key is aqquired by a client but not used.
if
matchkey
not
in
pending_keys
:
unmatch_period_count
+=
1
else
:
unmatch_period_count
=
0
# regenerate match key if key is aqquired but not used for a while
if
unmatch_period_count
*
ping_period
>
unmatch_timeout
+
ping_period
:
logging
.
info
(
"RPCServer: no incoming connections, regenerate key ..."
)
matchkey
=
base
.
random_key
(
rpc_key
+
":"
,
old_keyset
)
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
PUT
,
rpc_key
,
(
port
,
matchkey
)])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
unmatch_period_count
=
0
continue
conn
,
addr
=
listen_sock
.
accept
()
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
conn
,
4
))[
0
]
...
...
@@ -114,7 +132,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
keylen
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
conn
,
4
))[
0
]
key
=
py_str
(
base
.
recvall
(
conn
,
keylen
))
arr
=
key
.
split
()
expect_header
=
"client:"
+
rpc_key
+
matchkey
expect_header
=
"client:"
+
matchkey
server_key
=
"server:"
+
rpc_key
if
arr
[
0
]
!=
expect_header
:
conn
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_CODE_MISMATCH
))
...
...
python/tvm/contrib/rpc/tornado_util.py
View file @
14181340
...
...
@@ -48,6 +48,8 @@ class TCPHandler(object):
def
write_message
(
self
,
message
,
binary
=
True
):
assert
binary
if
self
.
_sock
is
None
:
raise
IOError
(
"socket is already closed"
)
self
.
_pending_write
.
append
(
message
)
self
.
_update_write
()
...
...
python/tvm/contrib/rpc/tracker.py
View file @
14181340
...
...
@@ -92,7 +92,9 @@ class PriorityScheduler(Scheduler):
value
=
self
.
_values
.
pop
(
0
)
item
=
heapq
.
heappop
(
self
.
_requests
)
callback
=
item
[
-
1
]
if
not
callback
(
value
):
if
callback
(
value
[
1
:]):
value
[
0
]
.
pending_matchkeys
.
remove
(
value
[
-
1
])
else
:
self
.
_values
.
append
(
value
)
def
put
(
self
,
value
):
...
...
@@ -124,6 +126,8 @@ class TCPEventHandler(tornado_util.TCPHandler):
self
.
_addr
=
addr
self
.
_init_req_nbytes
=
4
self
.
_info
=
{
"addr"
:
addr
}
# list of pending match keys that has not been used.
self
.
pending_matchkeys
=
set
()
self
.
_tracker
.
_connections
.
add
(
self
)
def
name
(
self
):
...
...
@@ -189,18 +193,27 @@ class TCPEventHandler(tornado_util.TCPHandler):
if
code
==
TrackerCode
.
PUT
:
key
=
args
[
1
]
port
,
matchkey
=
args
[
2
]
self
.
_tracker
.
put
(
key
,
(
self
.
_addr
[
0
],
port
,
matchkey
))
self
.
pending_matchkeys
.
add
(
matchkey
)
self
.
_tracker
.
put
(
key
,
(
self
,
self
.
_addr
[
0
],
port
,
matchkey
))
self
.
ret_value
(
TrackerCode
.
SUCCESS
)
elif
code
==
TrackerCode
.
REQUEST
:
key
=
args
[
1
]
user
=
args
[
2
]
priority
=
args
[
3
]
def
_cb
(
value
):
self
.
ret_value
([
TrackerCode
.
SUCCESS
,
value
])
# if the connection is already closed
if
not
self
.
_sock
:
return
False
try
:
self
.
ret_value
([
TrackerCode
.
SUCCESS
,
value
])
except
(
socket
.
sock_error
,
IOError
):
return
False
return
True
self
.
_tracker
.
request
(
key
,
user
,
priority
,
_cb
)
elif
code
==
TrackerCode
.
PING
:
self
.
ret_value
(
TrackerCode
.
SUCCESS
)
elif
code
==
TrackerCode
.
GET_PENDING_MATCHKEYS
:
self
.
ret_value
(
list
(
self
.
pending_matchkeys
))
elif
code
==
TrackerCode
.
STOP
:
# safe stop tracker
if
self
.
_tracker
.
_stop_key
==
args
[
1
]:
...
...
tests/python/contrib/test_rpc_tracker.py
View file @
14181340
...
...
@@ -23,6 +23,11 @@ def check_server_drop():
tproxy
=
proxy
.
Proxy
(
"localhost"
,
8881
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
))
tclient
=
rpc
.
connect_tracker
(
"localhost"
,
tserver
.
port
)
server0
=
rpc
.
Server
(
"localhost"
,
port
=
9099
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
),
key
=
"abc"
)
server1
=
rpc
.
Server
(
"localhost"
,
port
=
9099
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
),
...
...
@@ -34,6 +39,10 @@ def check_server_drop():
"localhost"
,
tproxy
.
port
,
is_proxy
=
True
,
key
=
"xyz1"
)
# Fault tolerence to un-handled requested value
_put
(
tclient
,
[
TrackerCode
.
REQUEST
,
"abc"
,
""
,
1
])
_put
(
tclient
,
[
TrackerCode
.
REQUEST
,
"xyz1"
,
""
,
1
])
# Fault tolerence to stale worker value
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abc"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abcxxx"
)])
...
...
@@ -58,14 +67,21 @@ def check_server_drop():
assert
f1
(
10
)
==
11
f1
=
remote2
.
get_function
(
"rpc.test2.addone"
)
assert
f1
(
10
)
==
11
except
tvm
.
TVMError
as
e
:
pass
remote3
=
tclient
.
request
(
"abc"
)
f1
=
remote3
.
get_function
(
"rpc.test2.addone"
)
remote3
=
tclient
.
request
(
"xyz1"
)
f1
=
remote3
.
get_function
(
"rpc.test2.addone"
)
assert
f1
(
10
)
==
11
check_timeout
(
0.01
,
0.1
)
check_timeout
(
2
,
0
)
tserver
.
terminate
()
server
2
.
terminate
()
server
0
.
terminate
()
server1
.
terminate
()
server2
.
terminate
()
server3
.
terminate
()
tproxy
.
terminate
()
except
ImportError
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment