Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3d67ea17
Commit
3d67ea17
authored
Apr 05, 2018
by
Tianqi Chen
Committed by
GitHub
Apr 05, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RPC] support tracker in proxy (#1082)
parent
79fc6672
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
235 additions
and
88 deletions
+235
-88
python/tvm/contrib/rpc/base.py
+23
-3
python/tvm/contrib/rpc/proxy.py
+126
-31
python/tvm/contrib/rpc/server.py
+51
-41
python/tvm/contrib/rpc/tracker.py
+1
-1
python/tvm/exec/rpc_proxy.py
+21
-5
tests/python/contrib/test_rpc_tracker.py
+13
-7
No files found.
python/tvm/contrib/rpc/base.py
View file @
3d67ea17
...
@@ -94,9 +94,29 @@ def recvjson(sock):
...
@@ -94,9 +94,29 @@ def recvjson(sock):
return
data
return
data
def
random_key
():
def
random_key
(
prefix
,
cmap
=
None
):
"""Generate a random key n"""
"""Generate a random key
return
str
(
random
.
random
())
Parameters
----------
prefix : str
The string prefix
cmap : dict
Conflict map
Returns
-------
key : str
The generated random key
"""
if
cmap
:
while
True
:
key
=
prefix
+
str
(
random
.
random
())
if
key
not
in
cmap
:
return
key
else
:
return
prefix
+
str
(
random
.
random
())
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
):
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
):
...
...
python/tvm/contrib/rpc/proxy.py
View file @
3d67ea17
...
@@ -25,22 +25,26 @@ except ImportError as error_msg:
...
@@ -25,22 +25,26 @@ except ImportError as error_msg:
raise
ImportError
(
"RPCProxy module requires tornado package
%
s"
%
error_msg
)
raise
ImportError
(
"RPCProxy module requires tornado package
%
s"
%
error_msg
)
from
.
import
base
from
.
import
base
from
.base
import
RPC_MAGIC
,
RPC_CODE_DUPLICATE
,
RPC_CODE_SUCCESS
,
RPC_CODE_MISMATCH
from
.base
import
TrackerCode
from
.server
import
_server_env
from
.server
import
_server_env
from
..._ffi.base
import
py_str
from
..._ffi.base
import
py_str
class
ForwardHandler
(
object
):
class
ForwardHandler
(
object
):
"""Forward handler to forward the message."""
"""Forward handler to forward the message."""
def
_init_handler
(
self
):
def
_init_handler
(
self
):
"""Initialize handler."""
"""Initialize handler."""
self
.
_init_message
=
bytes
()
self
.
_init_message
=
bytes
()
self
.
_init_req_nbytes
=
4
self
.
_init_req_nbytes
=
4
self
.
forward_proxy
=
None
self
.
_magic
=
None
self
.
_magic
=
None
self
.
timeout
=
None
self
.
timeout
=
None
self
.
_rpc_key_length
=
None
self
.
_rpc_key_length
=
None
self
.
rpc_key
=
None
self
.
_done
=
False
self
.
_done
=
False
self
.
_proxy
=
ProxyServerHandler
.
current
assert
self
.
_proxy
self
.
rpc_key
=
None
self
.
match_key
=
None
self
.
forward_proxy
=
None
def
__del__
(
self
):
def
__del__
(
self
):
logging
.
info
(
"Delete
%
s..."
,
self
.
name
())
logging
.
info
(
"Delete
%
s..."
,
self
.
name
())
...
@@ -53,7 +57,7 @@ class ForwardHandler(object):
...
@@ -53,7 +57,7 @@ class ForwardHandler(object):
if
self
.
_magic
is
None
:
if
self
.
_magic
is
None
:
assert
len
(
message
)
==
4
assert
len
(
message
)
==
4
self
.
_magic
=
struct
.
unpack
(
'@i'
,
message
)[
0
]
self
.
_magic
=
struct
.
unpack
(
'@i'
,
message
)[
0
]
if
self
.
_magic
!=
RPC_MAGIC
:
if
self
.
_magic
!=
base
.
RPC_MAGIC
:
logging
.
info
(
"Invalid RPC magic from
%
s"
,
self
.
name
())
logging
.
info
(
"Invalid RPC magic from
%
s"
,
self
.
name
())
self
.
close
()
self
.
close
()
self
.
_init_req_nbytes
=
4
self
.
_init_req_nbytes
=
4
...
@@ -64,13 +68,15 @@ class ForwardHandler(object):
...
@@ -64,13 +68,15 @@ class ForwardHandler(object):
elif
self
.
rpc_key
is
None
:
elif
self
.
rpc_key
is
None
:
assert
len
(
message
)
==
self
.
_rpc_key_length
assert
len
(
message
)
==
self
.
_rpc_key_length
self
.
rpc_key
=
py_str
(
message
)
self
.
rpc_key
=
py_str
(
message
)
# match key is used to do the matching
self
.
match_key
=
self
.
rpc_key
[
7
:]
.
split
()[
0
]
self
.
on_start
()
self
.
on_start
()
else
:
else
:
assert
False
assert
False
def
on_start
(
self
):
def
on_start
(
self
):
"""Event when the initialization is completed"""
"""Event when the initialization is completed"""
ProxyServerHandler
.
current
.
handler_ready
(
self
)
self
.
_proxy
.
handler_ready
(
self
)
def
on_data
(
self
,
message
):
def
on_data
(
self
,
message
):
"""on data"""
"""on data"""
...
@@ -105,12 +111,12 @@ class ForwardHandler(object):
...
@@ -105,12 +111,12 @@ class ForwardHandler(object):
"""on close event"""
"""on close event"""
assert
not
self
.
_done
assert
not
self
.
_done
logging
.
info
(
"RPCProxy:on_close
%
s ..."
,
self
.
name
())
logging
.
info
(
"RPCProxy:on_close
%
s ..."
,
self
.
name
())
if
self
.
rpc
_key
:
if
self
.
match
_key
:
key
=
self
.
rpc_key
[
7
:]
key
=
self
.
match_key
if
ProxyServerHandler
.
current
.
_client_pool
.
get
(
key
,
None
)
==
self
:
if
self
.
_proxy
.
_client_pool
.
get
(
key
,
None
)
==
self
:
ProxyServerHandler
.
current
.
_client_pool
.
pop
(
key
)
self
.
_proxy
.
_client_pool
.
pop
(
key
)
if
ProxyServerHandler
.
current
.
_server_pool
.
get
(
key
,
None
)
==
self
:
if
self
.
_proxy
.
_server_pool
.
get
(
key
,
None
)
==
self
:
ProxyServerHandler
.
current
.
_server_pool
.
pop
(
key
)
self
.
_proxy
.
_server_pool
.
pop
(
key
)
self
.
_done
=
True
self
.
_done
=
True
self
.
forward_proxy
=
None
self
.
forward_proxy
=
None
...
@@ -123,7 +129,7 @@ class TCPHandler(tornado_util.TCPHandler, ForwardHandler):
...
@@ -123,7 +129,7 @@ class TCPHandler(tornado_util.TCPHandler, ForwardHandler):
self
.
addr
=
addr
self
.
addr
=
addr
def
name
(
self
):
def
name
(
self
):
return
"TCPSocket
:
%
s:
%
s"
%
(
str
(
self
.
addr
),
self
.
rpc_key
)
return
"TCPSocket
Proxy:
%
s:
%
s"
%
(
str
(
self
.
addr
[
0
]
),
self
.
rpc_key
)
def
send_data
(
self
,
message
,
binary
=
True
):
def
send_data
(
self
,
message
,
binary
=
True
):
self
.
write_message
(
message
,
True
)
self
.
write_message
(
message
,
True
)
...
@@ -146,7 +152,7 @@ class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
...
@@ -146,7 +152,7 @@ class WebSocketHandler(websocket.WebSocketHandler, ForwardHandler):
self
.
_init_handler
()
self
.
_init_handler
()
def
name
(
self
):
def
name
(
self
):
return
"WebSocketProxy:
%
s"
%
(
self
.
rpc_key
)
return
"WebSocketProxy:
%
s"
%
(
self
.
rpc_key
)
def
on_message
(
self
,
message
):
def
on_message
(
self
,
message
):
self
.
on_data
(
message
)
self
.
on_data
(
message
)
...
@@ -192,15 +198,16 @@ class RequestHandler(tornado.web.RequestHandler):
...
@@ -192,15 +198,16 @@ class RequestHandler(tornado.web.RequestHandler):
self
.
write
(
self
.
page
)
self
.
write
(
self
.
page
)
class
ProxyServerHandler
(
object
):
class
ProxyServerHandler
(
object
):
"""Internal proxy server handler class."""
"""Internal proxy server handler class."""
current
=
None
current
=
None
def
__init__
(
self
,
def
__init__
(
self
,
sock
,
sock
,
listen_port
,
web_port
,
web_port
,
timeout_client
,
timeout_client
,
timeout_server
,
timeout_server
,
tracker_addr
,
index_page
=
None
,
index_page
=
None
,
resource_files
=
None
):
resource_files
=
None
):
assert
ProxyServerHandler
.
current
is
None
assert
ProxyServerHandler
.
current
is
None
...
@@ -232,6 +239,14 @@ class ProxyServerHandler(object):
...
@@ -232,6 +239,14 @@ class ProxyServerHandler(object):
self
.
_server_pool
=
{}
self
.
_server_pool
=
{}
self
.
timeout_client
=
timeout_client
self
.
timeout_client
=
timeout_client
self
.
timeout_server
=
timeout_server
self
.
timeout_server
=
timeout_server
# tracker information
self
.
_listen_port
=
listen_port
self
.
_tracker_addr
=
tracker_addr
self
.
_tracker_conn
=
None
self
.
_tracker_pending_puts
=
[]
self
.
_key_set
=
set
()
if
tracker_addr
:
logging
.
info
(
"Tracker address:
%
s"
,
str
(
tracker_addr
))
logging
.
info
(
"RPCProxy: Websock port bind to
%
d"
,
web_port
)
logging
.
info
(
"RPCProxy: Websock port bind to
%
d"
,
web_port
)
def
_on_event
(
self
,
_
):
def
_on_event
(
self
,
_
):
...
@@ -247,21 +262,85 @@ class ProxyServerHandler(object):
...
@@ -247,21 +262,85 @@ class ProxyServerHandler(object):
lhs
.
forward_proxy
=
rhs
lhs
.
forward_proxy
=
rhs
rhs
.
forward_proxy
=
lhs
rhs
.
forward_proxy
=
lhs
lhs
.
send_data
(
struct
.
pack
(
'@i'
,
RPC_CODE_SUCCESS
))
lhs
.
send_data
(
struct
.
pack
(
'@i'
,
base
.
RPC_CODE_SUCCESS
))
lhs
.
send_data
(
struct
.
pack
(
'@i'
,
len
(
rhs
.
rpc_key
)))
lhs
.
send_data
(
struct
.
pack
(
'@i'
,
len
(
rhs
.
rpc_key
)))
lhs
.
send_data
(
rhs
.
rpc_key
.
encode
(
"utf-8"
))
lhs
.
send_data
(
rhs
.
rpc_key
.
encode
(
"utf-8"
))
rhs
.
send_data
(
struct
.
pack
(
'@i'
,
RPC_CODE_SUCCESS
))
rhs
.
send_data
(
struct
.
pack
(
'@i'
,
base
.
RPC_CODE_SUCCESS
))
rhs
.
send_data
(
struct
.
pack
(
'@i'
,
len
(
lhs
.
rpc_key
)))
rhs
.
send_data
(
struct
.
pack
(
'@i'
,
len
(
lhs
.
rpc_key
)))
rhs
.
send_data
(
lhs
.
rpc_key
.
encode
(
"utf-8"
))
rhs
.
send_data
(
lhs
.
rpc_key
.
encode
(
"utf-8"
))
logging
.
info
(
"Pairup connect
%
s and
%
s"
,
lhs
.
name
(),
rhs
.
name
())
logging
.
info
(
"Pairup connect
%
s and
%
s"
,
lhs
.
name
(),
rhs
.
name
())
def
handler_ready
(
self
,
handler
):
def
_update_tracker
(
self
):
"""Report handler to be ready."""
"""Update information on tracker."""
logging
.
info
(
"Handler ready
%
s"
,
handler
.
name
())
try
:
key
=
handler
.
rpc_key
[
7
:]
if
self
.
_tracker_conn
is
None
:
self
.
_tracker_conn
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
_tracker_conn
.
connect
(
self
.
_tracker_addr
)
self
.
_tracker_conn
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
self
.
_tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
self
.
loop
.
stop
()
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
self
.
_tracker_addr
))
# just connect to tracker, need to update all keys
self
.
_tracker_pending_puts
=
self
.
_server_pool
.
keys
()
need_update_info
=
False
# report new connections
for
key
in
self
.
_tracker_pending_puts
:
rpc_key
,
match_key
=
key
.
split
(
":"
)
base
.
sendjson
(
self
.
_tracker_conn
,
[
TrackerCode
.
PUT
,
rpc_key
,
(
self
.
_listen_port
,
":"
+
match_key
)])
assert
base
.
recvjson
(
self
.
_tracker_conn
)
==
TrackerCode
.
SUCCESS
if
rpc_key
not
in
self
.
_key_set
:
self
.
_key_set
.
add
(
rpc_key
)
need_update_info
=
True
if
need_update_info
:
keylist
=
"["
+
","
.
join
(
self
.
_key_set
)
+
"]"
cinfo
=
{
"key"
:
"server:proxy"
+
keylist
}
base
.
sendjson
(
self
.
_tracker_conn
,
[
TrackerCode
.
UPDATE_INFO
,
cinfo
])
assert
base
.
recvjson
(
self
.
_tracker_conn
)
==
TrackerCode
.
SUCCESS
self
.
_tracker_pending_puts
=
[]
except
(
socket
.
error
,
IOError
)
as
err
:
retry_period
=
5
logging
.
info
(
"Lost tracker connection:
%
s, try reconnect in
%
g sec"
,
str
(
err
),
retry_period
)
self
.
_tracker_conn
.
close
()
self
.
_tracker_conn
=
None
new_pool
=
{}
keyset
=
set
(
self
.
_server_pool
.
keys
())
# re-generate the server match key, so old information is invalidated.
for
key
,
handle
in
self
.
_server_pool
.
items
():
rpc_key
,
_
=
key
.
split
(
":"
)
key
=
base
.
random_key
(
rpc_key
+
":"
,
keyset
)
new_pool
[
key
]
=
handle
keyset
.
add
(
key
)
self
.
_server_pool
=
new_pool
def
_callback
():
self
.
_update_tracker
()
self
.
loop
.
call_later
(
retry_period
,
_callback
)
def
_handler_ready_tracker_mode
(
self
,
handler
):
"""tracker mode to handle handler ready."""
if
handler
.
rpc_key
.
startswith
(
"server:"
):
key
=
base
.
random_key
(
handler
.
match_key
+
":"
,
self
.
_server_pool
)
handler
.
match_key
=
key
self
.
_server_pool
[
key
]
=
handler
self
.
_tracker_pending_puts
.
append
(
key
)
self
.
_update_tracker
()
else
:
if
handler
.
match_key
in
self
.
_server_pool
:
self
.
_pair_up
(
self
.
_server_pool
.
pop
(
handler
.
match_key
),
handler
)
else
:
handler
.
send_data
(
struct
.
pack
(
'@i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
signal_close
()
def
_handler_ready_proxy_mode
(
self
,
handler
):
"""Normal proxy mode when handler is ready."""
if
handler
.
rpc_key
.
startswith
(
"server:"
):
if
handler
.
rpc_key
.
startswith
(
"server:"
):
pool_src
,
pool_dst
=
self
.
_client_pool
,
self
.
_server_pool
pool_src
,
pool_dst
=
self
.
_client_pool
,
self
.
_server_pool
timeout
=
self
.
timeout_server
timeout
=
self
.
timeout_server
...
@@ -269,6 +348,7 @@ class ProxyServerHandler(object):
...
@@ -269,6 +348,7 @@ class ProxyServerHandler(object):
pool_src
,
pool_dst
=
self
.
_server_pool
,
self
.
_client_pool
pool_src
,
pool_dst
=
self
.
_server_pool
,
self
.
_client_pool
timeout
=
self
.
timeout_client
timeout
=
self
.
timeout_client
key
=
handler
.
match_key
if
key
in
pool_src
:
if
key
in
pool_src
:
self
.
_pair_up
(
pool_src
.
pop
(
key
),
handler
)
self
.
_pair_up
(
pool_src
.
pop
(
key
),
handler
)
return
return
...
@@ -280,29 +360,41 @@ class ProxyServerHandler(object):
...
@@ -280,29 +360,41 @@ class ProxyServerHandler(object):
logging
.
info
(
"Timeout client connection
%
s, cannot find match key=
%
s"
,
logging
.
info
(
"Timeout client connection
%
s, cannot find match key=
%
s"
,
handler
.
name
(),
key
)
handler
.
name
(),
key
)
pool_dst
.
pop
(
key
)
pool_dst
.
pop
(
key
)
handler
.
send_data
(
struct
.
pack
(
'@i'
,
RPC_CODE_MISMATCH
))
handler
.
send_data
(
struct
.
pack
(
'@i'
,
base
.
RPC_CODE_MISMATCH
))
handler
.
signal_close
()
handler
.
signal_close
()
self
.
loop
.
call_later
(
timeout
,
cleanup
)
self
.
loop
.
call_later
(
timeout
,
cleanup
)
else
:
else
:
logging
.
info
(
"Duplicate connection with same key=
%
s"
,
key
)
logging
.
info
(
"Duplicate connection with same key=
%
s"
,
key
)
handler
.
send_data
(
struct
.
pack
(
'@i'
,
RPC_CODE_DUPLICATE
))
handler
.
send_data
(
struct
.
pack
(
'@i'
,
base
.
RPC_CODE_DUPLICATE
))
handler
.
signal_close
()
handler
.
signal_close
()
def
handler_ready
(
self
,
handler
):
"""Report handler to be ready."""
logging
.
info
(
"Handler ready
%
s"
,
handler
.
name
())
if
self
.
_tracker_addr
:
self
.
_handler_ready_tracker_mode
(
handler
)
else
:
self
.
_handler_ready_proxy_mode
(
handler
)
def
run
(
self
):
def
run
(
self
):
"""Run the proxy server"""
"""Run the proxy server"""
ioloop
.
IOLoop
.
current
()
.
start
()
ioloop
.
IOLoop
.
current
()
.
start
()
def
_proxy_server
(
listen_sock
,
def
_proxy_server
(
listen_sock
,
listen_port
,
web_port
,
web_port
,
timeout_client
,
timeout_client
,
timeout_server
,
timeout_server
,
tracker_addr
,
index_page
,
index_page
,
resource_files
):
resource_files
):
handler
=
ProxyServerHandler
(
listen_sock
,
handler
=
ProxyServerHandler
(
listen_sock
,
listen_port
,
web_port
,
web_port
,
timeout_client
,
timeout_client
,
timeout_server
,
timeout_server
,
tracker_addr
,
index_page
,
index_page
,
resource_files
)
resource_files
)
handler
.
run
()
handler
.
run
()
...
@@ -346,6 +438,7 @@ class Proxy(object):
...
@@ -346,6 +438,7 @@ class Proxy(object):
web_port
=
0
,
web_port
=
0
,
timeout_client
=
600
,
timeout_client
=
600
,
timeout_server
=
600
,
timeout_server
=
600
,
tracker_addr
=
None
,
index_page
=
None
,
index_page
=
None
,
resource_files
=
None
):
resource_files
=
None
):
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
...
@@ -365,10 +458,12 @@ class Proxy(object):
...
@@ -365,10 +458,12 @@ class Proxy(object):
logging
.
info
(
"RPCProxy: client port bind to
%
s:
%
d"
,
host
,
self
.
port
)
logging
.
info
(
"RPCProxy: client port bind to
%
s:
%
d"
,
host
,
self
.
port
)
sock
.
listen
(
1
)
sock
.
listen
(
1
)
self
.
proc
=
multiprocessing
.
Process
(
self
.
proc
=
multiprocessing
.
Process
(
target
=
_proxy_server
,
args
=
(
sock
,
web_port
,
target
=
_proxy_server
,
timeout_client
,
timeout_server
,
args
=
(
sock
,
self
.
port
,
web_port
,
index_page
,
resource_files
))
timeout_client
,
timeout_server
,
tracker_addr
,
index_page
,
resource_files
))
self
.
proc
.
start
()
self
.
proc
.
start
()
sock
.
close
()
self
.
host
=
host
self
.
host
=
host
def
terminate
(
self
):
def
terminate
(
self
):
...
@@ -408,18 +503,18 @@ def websocket_proxy_server(url, key=""):
...
@@ -408,18 +503,18 @@ def websocket_proxy_server(url, key=""):
on_message
=
create_on_message
(
conn
)
on_message
=
create_on_message
(
conn
)
temp
=
_server_env
()
temp
=
_server_env
()
# Start connecton
# Start connecton
conn
.
write_message
(
struct
.
pack
(
'@i'
,
RPC_MAGIC
),
binary
=
True
)
conn
.
write_message
(
struct
.
pack
(
'@i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
key
=
"server:"
+
key
key
=
"server:"
+
key
conn
.
write_message
(
struct
.
pack
(
'@i'
,
len
(
key
)),
binary
=
True
)
conn
.
write_message
(
struct
.
pack
(
'@i'
,
len
(
key
)),
binary
=
True
)
conn
.
write_message
(
key
.
encode
(
"utf-8"
),
binary
=
True
)
conn
.
write_message
(
key
.
encode
(
"utf-8"
),
binary
=
True
)
msg
=
yield
conn
.
read_message
()
msg
=
yield
conn
.
read_message
()
assert
len
(
msg
)
>=
4
assert
len
(
msg
)
>=
4
magic
=
struct
.
unpack
(
'@i'
,
msg
[:
4
])[
0
]
magic
=
struct
.
unpack
(
'@i'
,
msg
[:
4
])[
0
]
if
magic
==
RPC_CODE_DUPLICATE
:
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
RPC_CODE_MISMATCH
:
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
RPC_CODE_SUCCESS
:
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
url
)
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
url
)
msg
=
msg
[
4
:]
msg
=
msg
[
4
:]
...
...
python/tvm/contrib/rpc/server.py
View file @
3d67ea17
...
@@ -91,7 +91,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
...
@@ -91,7 +91,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
"""
"""
# Report resource to tracker
# Report resource to tracker
if
tracker_conn
:
if
tracker_conn
:
matchkey
=
":"
+
base
.
random_key
(
)
matchkey
=
base
.
random_key
(
":"
)
base
.
sendjson
(
tracker_conn
,
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
PUT
,
rpc_key
,
(
port
,
matchkey
)])
[
TrackerCode
.
PUT
,
rpc_key
,
(
port
,
matchkey
)])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
...
@@ -130,19 +130,20 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
...
@@ -130,19 +130,20 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
# Server logic
# Server logic
tracker_conn
=
None
tracker_conn
=
None
while
True
:
while
True
:
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
# report status of current queue
cinfo
=
{
"key"
:
"server:"
+
rpc_key
}
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
UPDATE_INFO
,
cinfo
])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
try
:
try
:
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
# report status of current queue
cinfo
=
{
"key"
:
"server:"
+
rpc_key
}
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
UPDATE_INFO
,
cinfo
])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
# step 2: wait for in-coming connections
# step 2: wait for in-coming connections
conn
,
addr
,
opts
=
_accept_conn
(
sock
,
tracker_conn
)
conn
,
addr
,
opts
=
_accept_conn
(
sock
,
tracker_conn
)
except
(
socket
.
error
,
IOError
):
except
(
socket
.
error
,
IOError
):
...
@@ -161,40 +162,49 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
...
@@ -161,40 +162,49 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
# wait until server process finish or timeout
# wait until server process finish or timeout
server_proc
.
join
(
opts
.
get
(
"timeout"
,
None
))
server_proc
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
server_proc
.
is_alive
():
if
server_proc
.
is_alive
():
logging
.
info
(
"Timeout in RPC session, kill.."
)
logging
.
info
(
"
RPCServer:
Timeout in RPC session, kill.."
)
server_proc
.
terminate
()
server_proc
.
terminate
()
def
_connect_proxy_loop
(
addr
,
key
):
def
_connect_proxy_loop
(
addr
,
key
):
key
=
"server:"
+
key
key
=
"server:"
+
key
retry_count
=
0
max_retry
=
5
retry_period
=
5
while
True
:
while
True
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
try
:
sock
.
connect
(
addr
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_MAGIC
))
sock
.
connect
(
addr
)
sock
.
sendall
(
struct
.
pack
(
"@i"
,
len
(
key
)))
sock
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
sock
.
sendall
(
struct
.
pack
(
"@i"
,
len
(
key
)))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
keylen
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
keylen
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
process
=
multiprocessing
.
Process
(
process
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
sock
,
addr
))
target
=
_serve_loop
,
args
=
(
sock
,
addr
))
process
.
deamon
=
True
process
.
deamon
=
True
process
.
start
()
process
.
start
()
sock
.
close
()
sock
.
close
()
process
.
join
(
opts
.
get
(
"timeout"
,
None
))
process
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
process
.
is_alive
():
if
process
.
is_alive
():
logging
.
info
(
"Timeout in RPC session, kill.."
)
logging
.
info
(
"RPCProxyServer: Timeout in RPC session, kill.."
)
process
.
terminate
()
process
.
terminate
()
retry_count
=
0
except
(
socket
.
error
,
IOError
)
as
err
:
retry_count
+=
1
logging
.
info
(
"Error encountered
%
s, retry in
%
g sec"
,
str
(
err
),
retry_period
)
if
retry_count
>
max_retry
:
raise
RuntimeError
(
"Maximum retry error: last error:
%
s"
%
str
(
err
))
time
.
sleep
(
retry_period
)
def
_popen
(
cmd
):
def
_popen
(
cmd
):
proc
=
subprocess
.
Popen
(
cmd
,
proc
=
subprocess
.
Popen
(
cmd
,
...
...
python/tvm/contrib/rpc/tracker.py
View file @
3d67ea17
...
@@ -317,7 +317,7 @@ class Tracker(object):
...
@@ -317,7 +317,7 @@ class Tracker(object):
port_end
=
9199
):
port_end
=
9199
):
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
port
=
None
self
.
port
=
None
self
.
stop_key
=
base
.
random_key
()
self
.
stop_key
=
base
.
random_key
(
"tracker"
)
for
my_port
in
range
(
port
,
port_end
):
for
my_port
in
range
(
port
,
port_end
):
try
:
try
:
sock
.
bind
((
host
,
my_port
))
sock
.
bind
((
host
,
my_port
))
...
...
python/tvm/exec/rpc_proxy.py
View file @
3d67ea17
...
@@ -29,19 +29,35 @@ def main():
...
@@ -29,19 +29,35 @@ def main():
help
=
'the hostname of the server'
)
help
=
'the hostname of the server'
)
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
9090
,
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
9090
,
help
=
'The port of the PRC'
)
help
=
'The port of the PRC'
)
parser
.
add_argument
(
'--web-port'
,
type
=
int
,
default
=
9190
,
parser
.
add_argument
(
'--web-port'
,
type
=
int
,
default
=
8888
,
help
=
'The port of the http/websocket server'
)
help
=
'The port of the http/websocket server'
)
parser
.
add_argument
(
'--example-rpc'
,
type
=
bool
,
default
=
False
,
parser
.
add_argument
(
'--example-rpc'
,
type
=
bool
,
default
=
False
,
help
=
'Whether to switch on example rpc mode'
)
help
=
'Whether to switch on example rpc mode'
)
parser
.
add_argument
(
'--tracker'
,
type
=
str
,
default
=
""
,
help
=
"Report to RPC tracker"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
if
args
.
tracker
:
url
,
port
=
args
.
tracker
.
split
(
":"
)
port
=
int
(
port
)
tracker_addr
=
(
url
,
port
)
else
:
tracker_addr
=
None
if
args
.
example_rpc
:
if
args
.
example_rpc
:
index
,
js_files
=
find_example_resource
()
index
,
js_files
=
find_example_resource
()
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
prox
=
Proxy
(
args
.
host
,
web_port
=
args
.
web_port
,
index_page
=
index
,
port
=
args
.
port
,
resource_files
=
js_files
)
web_port
=
args
.
web_port
,
index_page
=
index
,
resource_files
=
js_files
,
tracker_addr
=
tracker_addr
)
else
:
else
:
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
web_port
=
args
.
web_port
)
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
web_port
=
args
.
web_port
,
tracker_addr
=
tracker_addr
)
prox
.
proc
.
join
()
prox
.
proc
.
join
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tests/python/contrib/test_rpc_tracker.py
View file @
3d67ea17
...
@@ -8,7 +8,7 @@ from tvm.contrib import rpc
...
@@ -8,7 +8,7 @@ from tvm.contrib import rpc
def
check_server_drop
():
def
check_server_drop
():
"""test when server drops"""
"""test when server drops"""
try
:
try
:
from
tvm.contrib.rpc
import
tracker
,
base
from
tvm.contrib.rpc
import
tracker
,
proxy
,
base
from
tvm.contrib.rpc.base
import
TrackerCode
from
tvm.contrib.rpc.base
import
TrackerCode
@tvm.register_func
(
"rpc.test2.addone"
)
@tvm.register_func
(
"rpc.test2.addone"
)
...
@@ -20,21 +20,24 @@ def check_server_drop():
...
@@ -20,21 +20,24 @@ def check_server_drop():
base
.
recvjson
(
tclient
.
_sock
)
base
.
recvjson
(
tclient
.
_sock
)
tserver
=
tracker
.
Tracker
(
"localhost"
,
8888
)
tserver
=
tracker
.
Tracker
(
"localhost"
,
8888
)
tproxy
=
proxy
.
Proxy
(
"localhost"
,
8881
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
))
tclient
=
rpc
.
connect_tracker
(
"localhost"
,
tserver
.
port
)
tclient
=
rpc
.
connect_tracker
(
"localhost"
,
tserver
.
port
)
server1
=
rpc
.
Server
(
server1
=
rpc
.
Server
(
"localhost"
,
port
=
9099
,
"localhost"
,
port
=
9099
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
),
tracker_addr
=
(
"localhost"
,
tserver
.
port
),
key
=
"xyz"
)
key
=
"xyz"
)
server2
=
rpc
.
Server
(
server2
=
rpc
.
Server
(
"localhost"
,
port
=
9099
,
"localhost"
,
tproxy
.
port
,
is_proxy
=
True
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
),
key
=
"xyz"
)
key
=
"xyz"
)
server3
=
rpc
.
Server
(
"localhost"
,
tproxy
.
port
,
is_proxy
=
True
,
key
=
"xyz1"
)
# Fault tolerence to stale worker value
# Fault tolerence to stale worker value
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abc"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abc"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abcxxx"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abcxxx"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server2
.
port
,
"abcxxx11"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
tproxy
.
port
,
"abcxxx11"
)])
# Fault tolerence server timeout
# Fault tolerence server timeout
def
check_timeout
(
timeout
,
sleeptime
):
def
check_timeout
(
timeout
,
sleeptime
):
...
@@ -60,8 +63,11 @@ def check_server_drop():
...
@@ -60,8 +63,11 @@ def check_server_drop():
check_timeout
(
0.01
,
0.1
)
check_timeout
(
0.01
,
0.1
)
check_timeout
(
2
,
0
)
check_timeout
(
2
,
0
)
tserver
.
terminate
()
server2
.
terminate
()
server1
.
terminate
()
server3
.
terminate
()
tproxy
.
terminate
()
except
ImportError
:
except
ImportError
:
print
(
"Skip because tornado is not available"
)
print
(
"Skip because tornado is not available"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment