Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3d67ea17
Commit
3d67ea17
authored
6 years ago
by
Tianqi Chen
Committed by
GitHub
6 years ago
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RPC] support tracker in proxy (#1082)
parent
79fc6672
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
109 additions
and
57 deletions
+109
-57
python/tvm/contrib/rpc/base.py
+23
-3
python/tvm/contrib/rpc/proxy.py
+0
-0
python/tvm/contrib/rpc/server.py
+51
-41
python/tvm/contrib/rpc/tracker.py
+1
-1
python/tvm/exec/rpc_proxy.py
+21
-5
tests/python/contrib/test_rpc_tracker.py
+13
-7
No files found.
python/tvm/contrib/rpc/base.py
View file @
3d67ea17
...
...
@@ -94,9 +94,29 @@ def recvjson(sock):
return
data
def
random_key
():
"""Generate a random key n"""
return
str
(
random
.
random
())
def
random_key
(
prefix
,
cmap
=
None
):
"""Generate a random key
Parameters
----------
prefix : str
The string prefix
cmap : dict
Conflict map
Returns
-------
key : str
The generated random key
"""
if
cmap
:
while
True
:
key
=
prefix
+
str
(
random
.
random
())
if
key
not
in
cmap
:
return
key
else
:
return
prefix
+
str
(
random
.
random
())
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
):
...
...
This diff is collapsed.
Click to expand it.
python/tvm/contrib/rpc/proxy.py
View file @
3d67ea17
This diff is collapsed.
Click to expand it.
python/tvm/contrib/rpc/server.py
View file @
3d67ea17
...
...
@@ -91,7 +91,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
"""
# Report resource to tracker
if
tracker_conn
:
matchkey
=
":"
+
base
.
random_key
(
)
matchkey
=
base
.
random_key
(
":"
)
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
PUT
,
rpc_key
,
(
port
,
matchkey
)])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
...
...
@@ -130,19 +130,20 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
# Server logic
tracker_conn
=
None
while
True
:
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
# report status of current queue
cinfo
=
{
"key"
:
"server:"
+
rpc_key
}
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
UPDATE_INFO
,
cinfo
])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
try
:
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
# report status of current queue
cinfo
=
{
"key"
:
"server:"
+
rpc_key
}
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
UPDATE_INFO
,
cinfo
])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
# step 2: wait for in-coming connections
conn
,
addr
,
opts
=
_accept_conn
(
sock
,
tracker_conn
)
except
(
socket
.
error
,
IOError
):
...
...
@@ -161,40 +162,49 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
# wait until server process finish or timeout
server_proc
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
server_proc
.
is_alive
():
logging
.
info
(
"Timeout in RPC session, kill.."
)
logging
.
info
(
"
RPCServer:
Timeout in RPC session, kill.."
)
server_proc
.
terminate
()
def
_connect_proxy_loop
(
addr
,
key
):
key
=
"server:"
+
key
retry_count
=
0
max_retry
=
5
retry_period
=
5
while
True
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
connect
(
addr
)
sock
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"@i"
,
len
(
key
)))
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
keylen
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
process
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
sock
,
addr
))
process
.
deamon
=
True
process
.
start
()
sock
.
close
()
process
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
process
.
is_alive
():
logging
.
info
(
"Timeout in RPC session, kill.."
)
process
.
terminate
()
try
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
connect
(
addr
)
sock
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
struct
.
pack
(
"@i"
,
len
(
key
)))
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
keylen
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
process
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
sock
,
addr
))
process
.
deamon
=
True
process
.
start
()
sock
.
close
()
process
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
process
.
is_alive
():
logging
.
info
(
"RPCProxyServer: Timeout in RPC session, kill.."
)
process
.
terminate
()
retry_count
=
0
except
(
socket
.
error
,
IOError
)
as
err
:
retry_count
+=
1
logging
.
info
(
"Error encountered
%
s, retry in
%
g sec"
,
str
(
err
),
retry_period
)
if
retry_count
>
max_retry
:
raise
RuntimeError
(
"Maximum retry error: last error:
%
s"
%
str
(
err
))
time
.
sleep
(
retry_period
)
def
_popen
(
cmd
):
proc
=
subprocess
.
Popen
(
cmd
,
...
...
This diff is collapsed.
Click to expand it.
python/tvm/contrib/rpc/tracker.py
View file @
3d67ea17
...
...
@@ -317,7 +317,7 @@ class Tracker(object):
port_end
=
9199
):
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
port
=
None
self
.
stop_key
=
base
.
random_key
()
self
.
stop_key
=
base
.
random_key
(
"tracker"
)
for
my_port
in
range
(
port
,
port_end
):
try
:
sock
.
bind
((
host
,
my_port
))
...
...
This diff is collapsed.
Click to expand it.
python/tvm/exec/rpc_proxy.py
View file @
3d67ea17
...
...
@@ -29,19 +29,35 @@ def main():
help
=
'the hostname of the server'
)
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
9090
,
help
=
'The port of the RPC'
)
parser
.
add_argument
(
'--web-port'
,
type
=
int
,
default
=
9190
,
parser
.
add_argument
(
'--web-port'
,
type
=
int
,
default
=
8888
,
help
=
'The port of the http/websocket server'
)
parser
.
add_argument
(
'--example-rpc'
,
type
=
bool
,
default
=
False
,
help
=
'Whether to switch on example rpc mode'
)
parser
.
add_argument
(
'--tracker'
,
type
=
str
,
default
=
""
,
help
=
"Report to RPC tracker"
)
args
=
parser
.
parse_args
()
logging
.
basicConfig
(
level
=
logging
.
INFO
)
if
args
.
tracker
:
url
,
port
=
args
.
tracker
.
split
(
":"
)
port
=
int
(
port
)
tracker_addr
=
(
url
,
port
)
else
:
tracker_addr
=
None
if
args
.
example_rpc
:
index
,
js_files
=
find_example_resource
()
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
web_port
=
args
.
web_port
,
index_page
=
index
,
resource_files
=
js_files
)
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
web_port
=
args
.
web_port
,
index_page
=
index
,
resource_files
=
js_files
,
tracker_addr
=
tracker_addr
)
else
:
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
web_port
=
args
.
web_port
)
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
web_port
=
args
.
web_port
,
tracker_addr
=
tracker_addr
)
prox
.
proc
.
join
()
if
__name__
==
"__main__"
:
...
...
This diff is collapsed.
Click to expand it.
tests/python/contrib/test_rpc_tracker.py
View file @
3d67ea17
...
...
@@ -8,7 +8,7 @@ from tvm.contrib import rpc
def
check_server_drop
():
"""test when server drops"""
try
:
from
tvm.contrib.rpc
import
tracker
,
base
from
tvm.contrib.rpc
import
tracker
,
proxy
,
base
from
tvm.contrib.rpc.base
import
TrackerCode
@tvm.register_func
(
"rpc.test2.addone"
)
...
...
@@ -20,21 +20,24 @@ def check_server_drop():
base
.
recvjson
(
tclient
.
_sock
)
tserver
=
tracker
.
Tracker
(
"localhost"
,
8888
)
tproxy
=
proxy
.
Proxy
(
"localhost"
,
8881
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
))
tclient
=
rpc
.
connect_tracker
(
"localhost"
,
tserver
.
port
)
server1
=
rpc
.
Server
(
"localhost"
,
port
=
9099
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
),
key
=
"xyz"
)
server2
=
rpc
.
Server
(
"localhost"
,
port
=
9099
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
),
"localhost"
,
tproxy
.
port
,
is_proxy
=
True
,
key
=
"xyz"
)
server3
=
rpc
.
Server
(
"localhost"
,
tproxy
.
port
,
is_proxy
=
True
,
key
=
"xyz1"
)
# Fault tolerance to stale worker value
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abc"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abcxxx"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server2
.
port
,
"abcxxx11"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
tproxy
.
port
,
"abcxxx11"
)])
# Fault tolerance server timeout
def
check_timeout
(
timeout
,
sleeptime
):
...
...
@@ -60,8 +63,11 @@ def check_server_drop():
check_timeout
(
0.01
,
0.1
)
check_timeout
(
2
,
0
)
tserver
.
terminate
()
server2
.
terminate
()
server1
.
terminate
()
server3
.
terminate
()
tproxy
.
terminate
()
except
ImportError
:
print
(
"Skip because tornado is not available"
)
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment