Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3d67ea17
Commit
3d67ea17
authored
Apr 05, 2018
by
Tianqi Chen
Committed by
GitHub
Apr 05, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RPC] support tracker in proxy (#1082)
parent
79fc6672
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
235 additions
and
88 deletions
+235
-88
python/tvm/contrib/rpc/base.py
+23
-3
python/tvm/contrib/rpc/proxy.py
+126
-31
python/tvm/contrib/rpc/server.py
+51
-41
python/tvm/contrib/rpc/tracker.py
+1
-1
python/tvm/exec/rpc_proxy.py
+21
-5
tests/python/contrib/test_rpc_tracker.py
+13
-7
No files found.
python/tvm/contrib/rpc/base.py
View file @
3d67ea17
...
...
@@ -94,9 +94,29 @@ def recvjson(sock):
return
data
def
random_key
():
"""Generate a random key n"""
return
str
(
random
.
random
())
def
random_key
(
prefix
,
cmap
=
None
):
"""Generate a random key
Parameters
----------
prefix : str
The string prefix
cmap : dict
Conflict map
Returns
-------
key : str
The generated random key
"""
if
cmap
:
while
True
:
key
=
prefix
+
str
(
random
.
random
())
if
key
not
in
cmap
:
return
key
else
:
return
prefix
+
str
(
random
.
random
())
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
):
...
...
python/tvm/contrib/rpc/proxy.py
View file @
3d67ea17
...
...
@@ -25,22 +25,26 @@ except ImportError as error_msg:
raise
ImportError
(
"RPCProxy module requires tornado package
%
s"
%
error_msg
)
from
.
import
base
from
.base
import
RPC_MAGIC
,
RPC_CODE_DUPLICATE
,
RPC_CODE_SUCCESS
,
RPC_CODE_MISMATCH
from
.base
import
TrackerCode
from
.server
import
_server_env
from
..._ffi.base
import
py_str
class
ForwardHandler
(
object
):
"""Forward handler to forward the message."""
def
_init_handler
(
self
):
"""Initialize handler."""
self
.
_init_message
=
bytes
()
self
.
_init_req_nbytes
=
4
self
.
forward_proxy
=
None
self
.
_magic
=
None
self
.
timeout
=
None
self
.
_rpc_key_length
=
None
self
.
rpc_key
=
None
self
.
_done
=
False
self
.
_proxy
=
ProxyServerHandler
.
current
assert
self
.
_proxy
self
.
rpc_key
=
None
self
.
match_key
=
None
self
.
forward_proxy
=
None
def
__del__
(
self
):
logging
.
info
(
"Delete
%
s..."
,
self
.
name
())
...
...
@@ -53,7 +57,7 @@ class ForwardHandler(object):
if
self
.
_magic
is
None
:
assert
len
(
message
)
==
4
self
.
_magic
=
struct
.
unpack
(
'@i'
,
message
)[
0
]
if
self
.
_magic
!=
RPC_MAGIC
:
if
self
.
_magic
!=
base
.
RPC_MAGIC
:
logging
.
info
(
"Invalid RPC magic from
%
s"
,
self
.
name
())
self
.
close
()
self
.
_init_req_nbytes
=
4
...
...
@@ -64,13 +68,15 @@ class ForwardHandler(object):
elif
self
.
rpc_key
is
None
:
assert
len
(
message
)
==
self
.
_rpc_key_length
self
.
rpc_key
=
py_str
(
message
)
# match key is used to do the matching
self
.
match_key
=
self
.
rpc_key
[
7
:]
.
split
()[
0
]
self
.
on_start
()
else
:
assert
False
def
on_start
(
self
):
"""Event when the initialization is completed"""
ProxyServerHandler
.
current
.
handler_ready
(
self
)
self
.
_proxy
.
handler_ready
(
self
)
def
on_data
(
self
,
message
):
"""on data"""
...
...
@@ -105,12 +111,12 @@ class ForwardHandler(object):
"""on close event"""
assert
not
self
.
_done
logging
.
info
(
"RPCProxy:on_close
%
s ..."
,
self
.
name
())
if
self
.
rpc
_key
:
key
=
self
.
rpc_key
[
7
:]
if
ProxyServerHandler
.
current
.
_client_pool
.
get
(
key
,
None
)
==
self
:
ProxyServerHandler
.
current
.
_client_pool
.
pop
(
key
)
if
ProxyServerHandler
.
current
.
_server_pool
.
get
(
key
,
None
)
==
self
:
ProxyServerHandler
.
current
.
_server_pool
.
pop
(
key
)
if
self
.
match
_key
:
key
=
self
.
match_key
if
self
.
_proxy
.
_client_pool
.
get
(
key
,
None
)
==
self
:
self
.
_proxy
.
_client_pool
.
pop
(
key
)
if
self
.
_proxy
.
_server_pool
.
get
(
key
,
None
)
==
self
:
self
.
_proxy
.
_server_pool
.
pop
(
key
)
self
.
_done
=
True
self
.
forward_proxy
=
None
...
...
@@ -123,7 +129,7 @@ class TCPHandler(tornado_util.TCPHandler, ForwardHandler):
self
.
addr
=
addr
def
name
(
self
):
return
"TCPSocket
:
%
s:
%
s"
%
(
str
(
self
.
addr
),
self
.
rpc_key
)
return
"TCPSocket
Proxy:
%
s:
%
s"
%
(
str
(
self
.
addr
[
0
]
),
self
.
rpc_key
)
def
send_data
(
self
,
message
,
binary
=
True
):
self
.
write_message
(
message
,
True
)
...
...
@@ -146,7 +152,7 @@ class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
self
.
_init_handler
()
def
name
(
self
):
return
"WebSocketProxy:
%
s"
%
(
self
.
rpc_key
)
return
"WebSocketProxy:
%
s"
%
(
self
.
rpc_key
)
def
on_message
(
self
,
message
):
self
.
on_data
(
message
)
...
...
@@ -192,15 +198,16 @@ class RequestHandler(tornado.web.RequestHandler):
self
.
write
(
self
.
page
)
class
ProxyServerHandler
(
object
):
"""Internal proxy server handler class."""
current
=
None
def
__init__
(
self
,
sock
,
listen_port
,
web_port
,
timeout_client
,
timeout_server
,
tracker_addr
,
index_page
=
None
,
resource_files
=
None
):
assert
ProxyServerHandler
.
current
is
None
...
...
@@ -232,6 +239,14 @@ class ProxyServerHandler(object):
self
.
_server_pool
=
{}
self
.
timeout_client
=
timeout_client
self
.
timeout_server
=
timeout_server
# tracker information
self
.
_listen_port
=
listen_port
self
.
_tracker_addr
=
tracker_addr
self
.
_tracker_conn
=
None
self
.
_tracker_pending_puts
=
[]
self
.
_key_set
=
set
()
if
tracker_addr
:
logging
.
info
(
"Tracker address:
%
s"
,
str
(
tracker_addr
))
logging
.
info
(
"RPCProxy: Websock port bind to
%
d"
,
web_port
)
def
_on_event
(
self
,
_
):
...
...
@@ -247,21 +262,85 @@ class ProxyServerHandler(object):
lhs
.
forward_proxy
=
rhs
rhs
.
forward_proxy
=
lhs
lhs
.
send_data
(
struct
.
pack
(
'@i'
,
RPC_CODE_SUCCESS
))
lhs
.
send_data
(
struct
.
pack
(
'@i'
,
base
.
RPC_CODE_SUCCESS
))
lhs
.
send_data
(
struct
.
pack
(
'@i'
,
len
(
rhs
.
rpc_key
)))
lhs
.
send_data
(
rhs
.
rpc_key
.
encode
(
"utf-8"
))
rhs
.
send_data
(
struct
.
pack
(
'@i'
,
RPC_CODE_SUCCESS
))
rhs
.
send_data
(
struct
.
pack
(
'@i'
,
base
.
RPC_CODE_SUCCESS
))
rhs
.
send_data
(
struct
.
pack
(
'@i'
,
len
(
lhs
.
rpc_key
)))
rhs
.
send_data
(
lhs
.
rpc_key
.
encode
(
"utf-8"
))
logging
.
info
(
"Pairup connect
%
s and
%
s"
,
lhs
.
name
(),
rhs
.
name
())
def
handler_ready
(
self
,
handler
):
"""Report handler to be ready."""
logging
.
info
(
"Handler ready
%
s"
,
handler
.
name
())
key
=
handler
.
rpc_key
[
7
:]
def
_update_tracker
(
self
):
"""Update information on tracker."""
try
:
if
self
.
_tracker_conn
is
None
:
self
.
_tracker_conn
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
_tracker_conn
.
connect
(
self
.
_tracker_addr
)
self
.
_tracker_conn
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
self
.
_tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
self
.
loop
.
stop
()
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_tracker_addr
))
# just connect to tracker, need to update all keys
self
.
_tracker_pending_puts
=
self
.
_server_pool
.
keys
()
need_update_info
=
False
# report new connections
for
key
in
self
.
_tracker_pending_puts
:
rpc_key
,
match_key
=
key
.
split
(
":"
)
base
.
sendjson
(
self
.
_tracker_conn
,
[
TrackerCode
.
PUT
,
rpc_key
,
(
self
.
_listen_port
,
":"
+
match_key
)])
assert
base
.
recvjson
(
self
.
_tracker_conn
)
==
TrackerCode
.
SUCCESS
if
rpc_key
not
in
self
.
_key_set
:
self
.
_key_set
.
add
(
rpc_key
)
need_update_info
=
True
if
need_update_info
:
keylist
=
"["
+
","
.
join
(
self
.
_key_set
)
+
"]"
cinfo
=
{
"key"
:
"server:proxy"
+
keylist
}
base
.
sendjson
(
self
.
_tracker_conn
,
[
TrackerCode
.
UPDATE_INFO
,
cinfo
])
assert
base
.
recvjson
(
self
.
_tracker_conn
)
==
TrackerCode
.
SUCCESS
self
.
_tracker_pending_puts
=
[]
except
(
socket
.
error
,
IOError
)
as
err
:
retry_period
=
5
logging
.
info
(
"Lost tracker connection:
%
s, try reconnect in
%
g sec"
,
str
(
err
),
retry_period
)
self
.
_tracker_conn
.
close
()
self
.
_tracker_conn
=
None
new_pool
=
{}
keyset
=
set
(
self
.
_server_pool
.
keys
())
# re-generate the server match key, so old information is invalidated.
for
key
,
handle
in
self
.
_server_pool
.
items
():
rpc_key
,
_
=
key
.
split
(
":"
)
key
=
base
.
random_key
(
rpc_key
+
":"
,
keyset
)
new_pool
[
key
]
=
handle
keyset
.
add
(
key
)
self
.
_server_pool
=
new_pool
def
_callback
():
self
.
_update_tracker
()
self
.
loop
.
call_later
(
retry_period
,
_callback
)
def
_handler_ready_tracker_mode
(
self
,
handler
):
"""tracker mode to handle handler ready."""
if
handler
.
rpc_key
.
startswith
(
"server:"
):
key
=
base
.
random_key
(
handler
.
match_key
+
":"
,
self
.
_server_pool
)
handler
.
match_key
=
key
self
.
_server_pool
[
key
]
=
handler
self
.
_tracker_pending_puts
.
append
(
key
)
self
.
_update_tracker
()
else
:
if
handler
.
match_key
in
self
.
_server_pool
:
self
.
_pair_up
(
self
.
_server_pool
.
pop
(
handler
.
match_key
),
handler
)
else
:
handler
.
send_data
(
struct
.
pack
(
'@i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
signal_close
()
def
_handler_ready_proxy_mode
(
self
,
handler
):
"""Normal proxy mode when handler is ready."""
if
handler
.
rpc_key
.
startswith
(
"server:"
):
pool_src
,
pool_dst
=
self
.
_client_pool
,
self
.
_server_pool
timeout
=
self
.
timeout_server
...
...
@@ -269,6 +348,7 @@ class ProxyServerHandler(object):
pool_src
,
pool_dst
=
self
.
_server_pool
,
self
.
_client_pool
timeout
=
self
.
timeout_client
key
=
handler
.
match_key
if
key
in
pool_src
:
self
.
_pair_up
(
pool_src
.
pop
(
key
),
handler
)
return
...
...
@@ -280,29 +360,41 @@ class ProxyServerHandler(object):
logging
.
info
(
"Timeout client connection
%
s, cannot find match key=
%
s"
,
handler
.
name
(),
key
)
pool_dst
.
pop
(
key
)
handler
.
send_data
(
struct
.
pack
(
'@i'
,
RPC_CODE_MISMATCH
))
handler
.
send_data
(
struct
.
pack
(
'@i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
signal_close
()
self
.
loop
.
call_later
(
timeout
,
cleanup
)
else
:
logging
.
info
(
"Duplicate connection with same key=
%
s"
,
key
)
handler
.
send_data
(
struct
.
pack
(
'@i'
,
RPC_CODE_DUPLICATE
))
handler
.
send_data
(
struct
.
pack
(
'@i'
,
base
.
RPC_CODE_DUPLICATE
))
handler
.
signal_close
()
def
handler_ready
(
self
,
handler
):
"""Report handler to be ready."""
logging
.
info
(
"Handler ready
%
s"
,
handler
.
name
())
if
self
.
_tracker_addr
:
self
.
_handler_ready_tracker_mode
(
handler
)
else
:
self
.
_handler_ready_proxy_mode
(
handler
)
def
run
(
self
):
"""Run the proxy server"""
ioloop
.
IOLoop
.
current
()
.
start
()
def
_proxy_server
(
listen_sock
,
listen_port
,
web_port
,
timeout_client
,
timeout_server
,
tracker_addr
,
index_page
,
resource_files
):
handler
=
ProxyServerHandler
(
listen_sock
,
listen_port
,
web_port
,
timeout_client
,
timeout_server
,
tracker_addr
,
index_page
,
resource_files
)
handler
.
run
()
...
...
@@ -346,6 +438,7 @@ class Proxy(object):
web_port
=
0
,
timeout_client
=
600
,
timeout_server
=
600
,
tracker_addr
=
None
,
index_page
=
None
,
resource_files
=
None
):
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
...
...
@@ -365,10 +458,12 @@ class Proxy(object):
logging
.
info
(
"RPCProxy: client port bind to
%
s:
%
d"
,
host
,
self
.
port
)
sock
.
listen
(
1
)
self
.
proc
=
multiprocessing
.
Process
(
target
=
_proxy_server
,
args
=
(
sock
,
web_port
,
timeout_client
,
timeout_server
,
index_page
,
resource_files
))
target
=
_proxy_server
,
args
=
(
sock
,
self
.
port
,
web_port
,
timeout_client
,
timeout_server
,
tracker_addr
,
index_page
,
resource_files
))
self
.
proc
.
start
()
sock
.
close
()
self
.
host
=
host
def
terminate
(
self
):
...
...
@@ -408,18 +503,18 @@ def websocket_proxy_server(url, key=""):
on_message
=
create_on_message
(
conn
)
temp
=
_server_env
()
# Start connecton
conn
.
write_message
(
struct
.
pack
(
'@i'
,
RPC_MAGIC
),
binary
=
True
)
conn
.
write_message
(
struct
.
pack
(
'@i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
key
=
"server:"
+
key
conn
.
write_message
(
struct
.
pack
(
'@i'
,
len
(
key
)),
binary
=
True
)
conn
.
write_message
(
key
.
encode
(
"utf-8"
),
binary
=
True
)
msg
=
yield
conn
.
read_message
()
assert
len
(
msg
)
>=
4
magic
=
struct
.
unpack
(
'@i'
,
msg
[:
4
])[
0
]
if
magic
==
RPC_CODE_DUPLICATE
:
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
RPC_CODE_MISMATCH
:
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
RPC_CODE_SUCCESS
:
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
url
)
msg
=
msg
[
4
:]
...
...
python/tvm/contrib/rpc/server.py
View file @
3d67ea17
...
...
@@ -91,7 +91,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
"""
# Report resource to tracker
if
tracker_conn
:
matchkey
=
":"
+
base
.
random_key
(
)
matchkey
=
base
.
random_key
(
":"
)
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
PUT
,
rpc_key
,
(
port
,
matchkey
)])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
...
...
@@ -130,19 +130,20 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
# Server logic
tracker_conn
=
None
while
True
:
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
# report status of current queue
cinfo
=
{
"key"
:
"server:"
+
rpc_key
}
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
UPDATE_INFO
,
cinfo
])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
try
:
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
# report status of current queue
cinfo
=
{
"key"
:
"server:"
+
rpc_key
}
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
UPDATE_INFO
,
cinfo
])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
# step 2: wait for in-coming connections
conn
,
addr
,
opts
=
_accept_conn
(
sock
,
tracker_conn
)
except
(
socket
.
error
,
IOError
):
...
...
@@ -161,40 +162,49 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
# wait until server process finish or timeout
server_proc
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
server_proc
.
is_alive
():
logging
.
info
(
"Timeout in RPC session, kill.."
)
logging
.
info
(
"
RPCServer:
Timeout in RPC session, kill.."
)
server_proc
.
terminate
()
def
_connect_proxy_loop
(
addr
,
key
):
key
=
"server:"
+
key
retry_count
=
0
max_retry
=
5
retry_period
=
5
while
True
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
connect
(
addr
)
sock
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"@i"
,
len
(
key
)))
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
keylen
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
process
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
sock
,
addr
))
process
.
deamon
=
True
process
.
start
()
sock
.
close
()
process
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
process
.
is_alive
():
logging
.
info
(
"Timeout in RPC session, kill.."
)
process
.
terminate
()
try
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
connect
(
addr
)
sock
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"@i"
,
len
(
key
)))
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
keylen
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
process
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
sock
,
addr
))
process
.
deamon
=
True
process
.
start
()
sock
.
close
()
process
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
process
.
is_alive
():
logging
.
info
(
"RPCProxyServer: Timeout in RPC session, kill.."
)
process
.
terminate
()
retry_count
=
0
except
(
socket
.
error
,
IOError
)
as
err
:
retry_count
+=
1
logging
.
info
(
"Error encountered
%
s, retry in
%
g sec"
,
str
(
err
),
retry_period
)
if
retry_count
>
max_retry
:
raise
RuntimeError
(
"Maximum retry error: last error:
%
s"
%
str
(
err
))
time
.
sleep
(
retry_period
)
def
_popen
(
cmd
):
proc
=
subprocess
.
Popen
(
cmd
,
...
...
python/tvm/contrib/rpc/tracker.py
View file @
3d67ea17
...
...
@@ -317,7 +317,7 @@ class Tracker(object):
port_end
=
9199
):
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
port
=
None
self
.
stop_key
=
base
.
random_key
()
self
.
stop_key
=
base
.
random_key
(
"tracker"
)
for
my_port
in
range
(
port
,
port_end
):
try
:
sock
.
bind
((
host
,
my_port
))
...
...
python/tvm/exec/rpc_proxy.py
View file @
3d67ea17
...
...
@@ -29,19 +29,35 @@ def main():
help
=
'the hostname of the server'
)
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
9090
,
help
=
'The port of the PRC'
)
parser
.
add_argument
(
'--web-port'
,
type
=
int
,
default
=
9190
,
parser
.
add_argument
(
'--web-port'
,
type
=
int
,
default
=
8888
,
help
=
'The port of the http/websocket server'
)
parser
.
add_argument
(
'--example-rpc'
,
type
=
bool
,
default
=
False
,
help
=
'Whether to switch on example rpc mode'
)
parser
.
add_argument
(
'--tracker'
,
type
=
str
,
default
=
""
,
help
=
"Report to RPC tracker"
)
args
=
parser
.
parse_args
()
logging
.
basicConfig
(
level
=
logging
.
INFO
)
if
args
.
tracker
:
url
,
port
=
args
.
tracker
.
split
(
":"
)
port
=
int
(
port
)
tracker_addr
=
(
url
,
port
)
else
:
tracker_addr
=
None
if
args
.
example_rpc
:
index
,
js_files
=
find_example_resource
()
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
web_port
=
args
.
web_port
,
index_page
=
index
,
resource_files
=
js_files
)
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
web_port
=
args
.
web_port
,
index_page
=
index
,
resource_files
=
js_files
,
tracker_addr
=
tracker_addr
)
else
:
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
web_port
=
args
.
web_port
)
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
web_port
=
args
.
web_port
,
tracker_addr
=
tracker_addr
)
prox
.
proc
.
join
()
if
__name__
==
"__main__"
:
...
...
tests/python/contrib/test_rpc_tracker.py
View file @
3d67ea17
...
...
@@ -8,7 +8,7 @@ from tvm.contrib import rpc
def
check_server_drop
():
"""test when server drops"""
try
:
from
tvm.contrib.rpc
import
tracker
,
base
from
tvm.contrib.rpc
import
tracker
,
proxy
,
base
from
tvm.contrib.rpc.base
import
TrackerCode
@tvm.register_func
(
"rpc.test2.addone"
)
...
...
@@ -20,21 +20,24 @@ def check_server_drop():
base
.
recvjson
(
tclient
.
_sock
)
tserver
=
tracker
.
Tracker
(
"localhost"
,
8888
)
tproxy
=
proxy
.
Proxy
(
"localhost"
,
8881
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
))
tclient
=
rpc
.
connect_tracker
(
"localhost"
,
tserver
.
port
)
server1
=
rpc
.
Server
(
"localhost"
,
port
=
9099
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
),
key
=
"xyz"
)
server2
=
rpc
.
Server
(
"localhost"
,
port
=
9099
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
),
"localhost"
,
tproxy
.
port
,
is_proxy
=
True
,
key
=
"xyz"
)
server3
=
rpc
.
Server
(
"localhost"
,
tproxy
.
port
,
is_proxy
=
True
,
key
=
"xyz1"
)
# Fault tolerence to stale worker value
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abc"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abcxxx"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server2
.
port
,
"abcxxx11"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
tproxy
.
port
,
"abcxxx11"
)])
# Fault tolerence server timeout
def
check_timeout
(
timeout
,
sleeptime
):
...
...
@@ -60,8 +63,11 @@ def check_server_drop():
check_timeout
(
0.01
,
0.1
)
check_timeout
(
2
,
0
)
tserver
.
terminate
()
server2
.
terminate
()
server1
.
terminate
()
server3
.
terminate
()
tproxy
.
terminate
()
except
ImportError
:
print
(
"Skip because tornado is not available"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment