Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
3d67ea17
Commit
3d67ea17
authored
Apr 05, 2018
by
Tianqi Chen
Committed by
GitHub
Apr 05, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[RPC] support tracker in proxy (#1082)
parent
79fc6672
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
109 additions
and
57 deletions
+109
-57
python/tvm/contrib/rpc/base.py
+23
-3
python/tvm/contrib/rpc/proxy.py
+0
-0
python/tvm/contrib/rpc/server.py
+51
-41
python/tvm/contrib/rpc/tracker.py
+1
-1
python/tvm/exec/rpc_proxy.py
+21
-5
tests/python/contrib/test_rpc_tracker.py
+13
-7
No files found.
python/tvm/contrib/rpc/base.py
View file @
3d67ea17
...
@@ -94,9 +94,29 @@ def recvjson(sock):
...
@@ -94,9 +94,29 @@ def recvjson(sock):
return
data
return
data
def
random_key
():
def
random_key
(
prefix
,
cmap
=
None
):
"""Generate a random key n"""
"""Generate a random key
return
str
(
random
.
random
())
Parameters
----------
prefix : str
The string prefix
cmap : dict
Conflict map
Returns
-------
key : str
The generated random key
"""
if
cmap
:
while
True
:
key
=
prefix
+
str
(
random
.
random
())
if
key
not
in
cmap
:
return
key
else
:
return
prefix
+
str
(
random
.
random
())
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
):
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
):
...
...
python/tvm/contrib/rpc/proxy.py
View file @
3d67ea17
This diff is collapsed.
Click to expand it.
python/tvm/contrib/rpc/server.py
View file @
3d67ea17
...
@@ -91,7 +91,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
...
@@ -91,7 +91,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
"""
"""
# Report resource to tracker
# Report resource to tracker
if
tracker_conn
:
if
tracker_conn
:
matchkey
=
":"
+
base
.
random_key
(
)
matchkey
=
base
.
random_key
(
":"
)
base
.
sendjson
(
tracker_conn
,
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
PUT
,
rpc_key
,
(
port
,
matchkey
)])
[
TrackerCode
.
PUT
,
rpc_key
,
(
port
,
matchkey
)])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
...
@@ -130,19 +130,20 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
...
@@ -130,19 +130,20 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
# Server logic
# Server logic
tracker_conn
=
None
tracker_conn
=
None
while
True
:
while
True
:
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
# report status of current queue
cinfo
=
{
"key"
:
"server:"
+
rpc_key
}
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
UPDATE_INFO
,
cinfo
])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
try
:
try
:
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
raise
RuntimeError
(
"
%
s is not RPC Tracker"
%
str
(
tracker_addr
))
# report status of current queue
cinfo
=
{
"key"
:
"server:"
+
rpc_key
}
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
UPDATE_INFO
,
cinfo
])
assert
base
.
recvjson
(
tracker_conn
)
==
TrackerCode
.
SUCCESS
# step 2: wait for in-coming connections
# step 2: wait for in-coming connections
conn
,
addr
,
opts
=
_accept_conn
(
sock
,
tracker_conn
)
conn
,
addr
,
opts
=
_accept_conn
(
sock
,
tracker_conn
)
except
(
socket
.
error
,
IOError
):
except
(
socket
.
error
,
IOError
):
...
@@ -161,40 +162,49 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
...
@@ -161,40 +162,49 @@ def _listen_loop(sock, port, rpc_key, tracker_addr):
# wait until server process finish or timeout
# wait until server process finish or timeout
server_proc
.
join
(
opts
.
get
(
"timeout"
,
None
))
server_proc
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
server_proc
.
is_alive
():
if
server_proc
.
is_alive
():
logging
.
info
(
"Timeout in RPC session, kill.."
)
logging
.
info
(
"
RPCServer:
Timeout in RPC session, kill.."
)
server_proc
.
terminate
()
server_proc
.
terminate
()
def
_connect_proxy_loop
(
addr
,
key
):
def
_connect_proxy_loop
(
addr
,
key
):
key
=
"server:"
+
key
key
=
"server:"
+
key
retry_count
=
0
max_retry
=
5
retry_period
=
5
while
True
:
while
True
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
try
:
sock
.
connect
(
addr
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_MAGIC
))
sock
.
connect
(
addr
)
sock
.
sendall
(
struct
.
pack
(
"@i"
,
len
(
key
)))
sock
.
sendall
(
struct
.
pack
(
"@i"
,
base
.
RPC_MAGIC
))
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
sock
.
sendall
(
struct
.
pack
(
"@i"
,
len
(
key
)))
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
sock
.
sendall
(
key
.
encode
(
"utf-8"
))
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
magic
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
logging
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
keylen
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
keylen
=
struct
.
unpack
(
"@i"
,
base
.
recvall
(
sock
,
4
))[
0
]
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
logging
.
info
(
"RPCProxy connected to
%
s"
,
str
(
addr
))
process
=
multiprocessing
.
Process
(
process
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
sock
,
addr
))
target
=
_serve_loop
,
args
=
(
sock
,
addr
))
process
.
deamon
=
True
process
.
deamon
=
True
process
.
start
()
process
.
start
()
sock
.
close
()
sock
.
close
()
process
.
join
(
opts
.
get
(
"timeout"
,
None
))
process
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
process
.
is_alive
():
if
process
.
is_alive
():
logging
.
info
(
"Timeout in RPC session, kill.."
)
logging
.
info
(
"RPCProxyServer: Timeout in RPC session, kill.."
)
process
.
terminate
()
process
.
terminate
()
retry_count
=
0
except
(
socket
.
error
,
IOError
)
as
err
:
retry_count
+=
1
logging
.
info
(
"Error encountered
%
s, retry in
%
g sec"
,
str
(
err
),
retry_period
)
if
retry_count
>
max_retry
:
raise
RuntimeError
(
"Maximum retry error: last error:
%
s"
%
str
(
err
))
time
.
sleep
(
retry_period
)
def
_popen
(
cmd
):
def
_popen
(
cmd
):
proc
=
subprocess
.
Popen
(
cmd
,
proc
=
subprocess
.
Popen
(
cmd
,
...
...
python/tvm/contrib/rpc/tracker.py
View file @
3d67ea17
...
@@ -317,7 +317,7 @@ class Tracker(object):
...
@@ -317,7 +317,7 @@ class Tracker(object):
port_end
=
9199
):
port_end
=
9199
):
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
port
=
None
self
.
port
=
None
self
.
stop_key
=
base
.
random_key
()
self
.
stop_key
=
base
.
random_key
(
"tracker"
)
for
my_port
in
range
(
port
,
port_end
):
for
my_port
in
range
(
port
,
port_end
):
try
:
try
:
sock
.
bind
((
host
,
my_port
))
sock
.
bind
((
host
,
my_port
))
...
...
python/tvm/exec/rpc_proxy.py
View file @
3d67ea17
...
@@ -29,19 +29,35 @@ def main():
...
@@ -29,19 +29,35 @@ def main():
help
=
'the hostname of the server'
)
help
=
'the hostname of the server'
)
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
9090
,
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
9090
,
help
=
'The port of the PRC'
)
help
=
'The port of the PRC'
)
parser
.
add_argument
(
'--web-port'
,
type
=
int
,
default
=
9190
,
parser
.
add_argument
(
'--web-port'
,
type
=
int
,
default
=
8888
,
help
=
'The port of the http/websocket server'
)
help
=
'The port of the http/websocket server'
)
parser
.
add_argument
(
'--example-rpc'
,
type
=
bool
,
default
=
False
,
parser
.
add_argument
(
'--example-rpc'
,
type
=
bool
,
default
=
False
,
help
=
'Whether to switch on example rpc mode'
)
help
=
'Whether to switch on example rpc mode'
)
parser
.
add_argument
(
'--tracker'
,
type
=
str
,
default
=
""
,
help
=
"Report to RPC tracker"
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
if
args
.
tracker
:
url
,
port
=
args
.
tracker
.
split
(
":"
)
port
=
int
(
port
)
tracker_addr
=
(
url
,
port
)
else
:
tracker_addr
=
None
if
args
.
example_rpc
:
if
args
.
example_rpc
:
index
,
js_files
=
find_example_resource
()
index
,
js_files
=
find_example_resource
()
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
prox
=
Proxy
(
args
.
host
,
web_port
=
args
.
web_port
,
index_page
=
index
,
port
=
args
.
port
,
resource_files
=
js_files
)
web_port
=
args
.
web_port
,
index_page
=
index
,
resource_files
=
js_files
,
tracker_addr
=
tracker_addr
)
else
:
else
:
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
web_port
=
args
.
web_port
)
prox
=
Proxy
(
args
.
host
,
port
=
args
.
port
,
web_port
=
args
.
web_port
,
tracker_addr
=
tracker_addr
)
prox
.
proc
.
join
()
prox
.
proc
.
join
()
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
...
...
tests/python/contrib/test_rpc_tracker.py
View file @
3d67ea17
...
@@ -8,7 +8,7 @@ from tvm.contrib import rpc
...
@@ -8,7 +8,7 @@ from tvm.contrib import rpc
def
check_server_drop
():
def
check_server_drop
():
"""test when server drops"""
"""test when server drops"""
try
:
try
:
from
tvm.contrib.rpc
import
tracker
,
base
from
tvm.contrib.rpc
import
tracker
,
proxy
,
base
from
tvm.contrib.rpc.base
import
TrackerCode
from
tvm.contrib.rpc.base
import
TrackerCode
@tvm.register_func
(
"rpc.test2.addone"
)
@tvm.register_func
(
"rpc.test2.addone"
)
...
@@ -20,21 +20,24 @@ def check_server_drop():
...
@@ -20,21 +20,24 @@ def check_server_drop():
base
.
recvjson
(
tclient
.
_sock
)
base
.
recvjson
(
tclient
.
_sock
)
tserver
=
tracker
.
Tracker
(
"localhost"
,
8888
)
tserver
=
tracker
.
Tracker
(
"localhost"
,
8888
)
tproxy
=
proxy
.
Proxy
(
"localhost"
,
8881
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
))
tclient
=
rpc
.
connect_tracker
(
"localhost"
,
tserver
.
port
)
tclient
=
rpc
.
connect_tracker
(
"localhost"
,
tserver
.
port
)
server1
=
rpc
.
Server
(
server1
=
rpc
.
Server
(
"localhost"
,
port
=
9099
,
"localhost"
,
port
=
9099
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
),
tracker_addr
=
(
"localhost"
,
tserver
.
port
),
key
=
"xyz"
)
key
=
"xyz"
)
server2
=
rpc
.
Server
(
server2
=
rpc
.
Server
(
"localhost"
,
port
=
9099
,
"localhost"
,
tproxy
.
port
,
is_proxy
=
True
,
tracker_addr
=
(
"localhost"
,
tserver
.
port
),
key
=
"xyz"
)
key
=
"xyz"
)
server3
=
rpc
.
Server
(
"localhost"
,
tproxy
.
port
,
is_proxy
=
True
,
key
=
"xyz1"
)
# Fault tolerence to stale worker value
# Fault tolerence to stale worker value
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abc"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abc"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abcxxx"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server1
.
port
,
"abcxxx"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
server2
.
port
,
"abcxxx11"
)])
_put
(
tclient
,
[
TrackerCode
.
PUT
,
"xyz"
,
(
tproxy
.
port
,
"abcxxx11"
)])
# Fault tolerence server timeout
# Fault tolerence server timeout
def
check_timeout
(
timeout
,
sleeptime
):
def
check_timeout
(
timeout
,
sleeptime
):
...
@@ -60,8 +63,11 @@ def check_server_drop():
...
@@ -60,8 +63,11 @@ def check_server_drop():
check_timeout
(
0.01
,
0.1
)
check_timeout
(
0.01
,
0.1
)
check_timeout
(
2
,
0
)
check_timeout
(
2
,
0
)
tserver
.
terminate
()
server2
.
terminate
()
server1
.
terminate
()
server3
.
terminate
()
tproxy
.
terminate
()
except
ImportError
:
except
ImportError
:
print
(
"Skip because tornado is not available"
)
print
(
"Skip because tornado is not available"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment