Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d0eb2d3d
Commit
d0eb2d3d
authored
Jun 13, 2018
by
Lianmin Zheng
Committed by
Tianqi Chen
Jun 12, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add silent mode to rpc server and rpc tracker (#1268)
parent
558cf098
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
91 additions
and
38 deletions
+91
-38
python/tvm/contrib/rpc/base.py
+5
-1
python/tvm/contrib/rpc/proxy.py
+1
-1
python/tvm/contrib/rpc/server.py
+57
-25
python/tvm/contrib/rpc/tracker.py
+11
-4
python/tvm/exec/rpc_server.py
+7
-3
python/tvm/exec/rpc_tracker.py
+10
-3
src/runtime/rpc/rpc_server_env.cc
+0
-1
No files found.
python/tvm/contrib/rpc/base.py
View file @
d0eb2d3d
...
...
@@ -120,7 +120,7 @@ def random_key(prefix, cmap=None):
return
prefix
+
str
(
random
.
random
())
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
):
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
,
silent
=
False
):
"""Connect to a TPC address with retry
This function is only reliable to short period of server restart.
...
...
@@ -135,6 +135,9 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
retry_period : float
Number of seconds before we retry again.
silent: bool
whether run in silent mode
"""
tstart
=
time
.
time
()
while
True
:
...
...
@@ -149,6 +152,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
if
period
>
timeout
:
raise
RuntimeError
(
"Failed to connect to server
%
s"
%
str
(
addr
))
if
not
silent
:
logging
.
info
(
"Cannot connect to tracker
%
s, retry in
%
g secs..."
,
str
(
addr
),
retry_period
)
time
.
sleep
(
retry_period
)
...
...
python/tvm/contrib/rpc/proxy.py
View file @
d0eb2d3d
...
...
@@ -536,7 +536,7 @@ def websocket_proxy_server(url, key=""):
def
_connect
(
key
):
conn
=
yield
websocket
.
websocket_connect
(
url
)
on_message
=
create_on_message
(
conn
)
temp
=
_server_env
(
None
)
temp
=
_server_env
(
None
,
None
)
# Start connecton
conn
.
write_message
(
struct
.
pack
(
'<i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
key
=
"server:"
+
key
...
...
python/tvm/contrib/rpc/server.py
View file @
d0eb2d3d
...
...
@@ -19,6 +19,7 @@ import logging
import
multiprocessing
import
subprocess
import
time
import
sys
from
..._ffi.function
import
register_func
from
..._ffi.base
import
py_str
...
...
@@ -28,9 +29,12 @@ from .. import util
from
.
import
base
from
.
base
import
TrackerCode
def
_server_env
(
load_library
):
def
_server_env
(
load_library
,
logger
):
"""Server environment function return temp dir"""
temp
=
util
.
tempdir
()
if
logger
is
None
:
logger
=
logging
.
getLogger
()
# pylint: disable=unused-variable
@register_func
(
"tvm.contrib.rpc.server.workpath"
)
def
get_workpath
(
path
):
...
...
@@ -41,7 +45,7 @@ def _server_env(load_library):
"""Load module from remote side."""
path
=
temp
.
relpath
(
file_name
)
m
=
_load_module
(
path
)
logg
ing
.
info
(
"load_module
%
s"
,
path
)
logg
er
.
info
(
"load_module
%
s"
,
path
)
return
m
libs
=
[]
...
...
@@ -49,18 +53,21 @@ def _server_env(load_library):
for
file_name
in
load_library
:
file_name
=
find_lib_path
(
file_name
)[
0
]
libs
.
append
(
ctypes
.
CDLL
(
file_name
,
ctypes
.
RTLD_GLOBAL
))
logg
ing
.
info
(
"Load additional library
%
s"
,
file_name
)
logg
er
.
info
(
"Load additional library
%
s"
,
file_name
)
temp
.
libs
=
libs
return
temp
def
_serve_loop
(
sock
,
addr
,
load_library
):
def
_serve_loop
(
sock
,
addr
,
load_library
,
silent
):
"""Server loop"""
logger
=
logging
.
getLogger
(
"RPCServer"
)
if
silent
:
logger
.
disabled
=
True
sockfd
=
sock
.
fileno
()
temp
=
_server_env
(
load_library
)
temp
=
_server_env
(
load_library
,
logger
)
base
.
_ServerLoop
(
sockfd
)
temp
.
remove
()
logg
ing
.
info
(
"Finish serving
%
s"
,
addr
)
logg
er
.
info
(
"Finish serving
%
s"
,
addr
)
def
_parse_server_opt
(
opts
):
...
...
@@ -71,8 +78,12 @@ def _parse_server_opt(opts):
ret
[
"timeout"
]
=
float
(
kv
[
9
:])
return
ret
def
_listen_loop
(
sock
,
port
,
rpc_key
,
tracker_addr
,
load_library
,
custom_addr
):
"""Lisenting loop of the server master."""
def
_listen_loop
(
sock
,
port
,
rpc_key
,
tracker_addr
,
load_library
,
custom_addr
,
silent
):
"""Listening loop of the server master."""
logger
=
logging
.
getLogger
(
"RPCServer"
)
if
silent
:
logger
.
disabled
=
True
def
_accept_conn
(
listen_sock
,
tracker_conn
,
ping_period
=
2
):
"""Accept connection from the other places.
...
...
@@ -115,7 +126,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
unmatch_period_count
=
0
# regenerate match key if key is acquired but not used for a while
if
unmatch_period_count
*
ping_period
>
unmatch_timeout
+
ping_period
:
logg
ing
.
info
(
"RPCServer:
no incoming connections, regenerate key ..."
)
logg
er
.
info
(
"
no incoming connections, regenerate key ..."
)
matchkey
=
base
.
random_key
(
rpc_key
+
":"
,
old_keyset
)
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
PUT
,
rpc_key
,
(
port
,
matchkey
),
...
...
@@ -136,7 +147,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
if
arr
[
0
]
!=
expect_header
:
conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
close
()
logg
ing
.
info
(
"RPCServer:
mismatch key from
%
s"
,
addr
)
logg
er
.
info
(
"
mismatch key from
%
s"
,
addr
)
continue
else
:
conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_CODE_SUCCESS
))
...
...
@@ -150,7 +161,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
try
:
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
,
silent
=
silent
)
tracker_conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"<i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
...
...
@@ -169,10 +180,16 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
tracker_conn
.
close
()
tracker_conn
=
None
continue
except
RuntimeError
as
exc
:
if
silent
:
return
else
:
raise
exc
# step 3: serving
logging
.
info
(
"RPCServer: connection from
%
s"
,
addr
)
server_proc
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
conn
,
addr
,
load_library
))
logger
.
info
(
"connection from
%
s"
,
addr
)
server_proc
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
conn
,
addr
,
load_library
,
silent
))
server_proc
.
deamon
=
True
server_proc
.
start
()
# close from our side.
...
...
@@ -180,11 +197,14 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
# wait until server process finish or timeout
server_proc
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
server_proc
.
is_alive
():
logg
ing
.
info
(
"RPCServer:
Timeout in RPC session, kill.."
)
logg
er
.
info
(
"
Timeout in RPC session, kill.."
)
server_proc
.
terminate
()
def
_connect_proxy_loop
(
addr
,
key
,
load_library
):
def
_connect_proxy_loop
(
addr
,
key
,
load_library
,
silent
):
logger
=
logging
.
getLogger
(
"RPCProxy"
)
if
silent
:
logger
.
disabled
=
True
key
=
"server:"
+
key
retry_count
=
0
max_retry
=
5
...
...
@@ -200,26 +220,26 @@ def _connect_proxy_loop(addr, key, load_library):
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logg
ing
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
logg
er
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
keylen
=
struct
.
unpack
(
"<i"
,
base
.
recvall
(
sock
,
4
))[
0
]
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logg
ing
.
info
(
"RPCProxy
connected to
%
s"
,
str
(
addr
))
logg
er
.
info
(
"
connected to
%
s"
,
str
(
addr
))
process
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
sock
,
addr
,
load_library
))
target
=
_serve_loop
,
args
=
(
sock
,
addr
,
load_library
,
silent
))
process
.
deamon
=
True
process
.
start
()
sock
.
close
()
process
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
process
.
is_alive
():
logg
ing
.
info
(
"RPCProxyServer:
Timeout in RPC session, kill.."
)
logg
er
.
info
(
"
Timeout in RPC session, kill.."
)
process
.
terminate
()
retry_count
=
0
except
(
socket
.
error
,
IOError
)
as
err
:
retry_count
+=
1
logg
ing
.
info
(
"Error encountered
%
s, retry in
%
g sec"
,
str
(
err
),
retry_period
)
logg
er
.
info
(
"Error encountered
%
s, retry in
%
g sec"
,
str
(
err
),
retry_period
)
if
retry_count
>
max_retry
:
raise
RuntimeError
(
"Maximum retry error: last error:
%
s"
%
str
(
err
))
time
.
sleep
(
retry_period
)
...
...
@@ -264,6 +284,9 @@ class Server(object):
This is recommended to switch on if we want to do local RPC demonstration
for GPU devices to avoid fork safety issues.
silent: bool, optional
Whether run this server in silent mode.
key : str, optional
The key used to identify the server in Proxy connection.
...
...
@@ -276,6 +299,7 @@ class Server(object):
port_end
=
9199
,
is_proxy
=
False
,
use_popen
=
False
,
silent
=
False
,
tracker_addr
=
None
,
key
=
""
,
load_library
=
None
,
...
...
@@ -290,8 +314,12 @@ class Server(object):
self
.
libs
=
[]
self
.
custom_addr
=
custom_addr
self
.
logger
=
logging
.
getLogger
(
"RPCServer"
)
if
silent
:
self
.
logger
.
disabled
=
True
if
use_popen
:
cmd
=
[
"python"
,
cmd
=
[
sys
.
executable
,
"-m"
,
"tvm.exec.rpc_server"
,
"--host=
%
s"
%
host
,
"--port=
%
s"
%
port
]
...
...
@@ -303,11 +331,14 @@ class Server(object):
cmd
+=
[
"--load-library"
,
load_library
]
if
custom_addr
:
cmd
+=
[
"--custom-addr"
,
custom_addr
]
if
silent
:
cmd
+=
[
"--silent"
]
self
.
proc
=
multiprocessing
.
Process
(
target
=
subprocess
.
check_call
,
args
=
(
cmd
,))
self
.
proc
.
deamon
=
True
self
.
proc
.
start
()
time
.
sleep
(
1
)
time
.
sleep
(
0.5
)
elif
not
is_proxy
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
port
=
None
...
...
@@ -323,17 +354,18 @@ class Server(object):
raise
sock_err
if
not
self
.
port
:
raise
ValueError
(
"cannot bind to any port in [
%
d,
%
d)"
%
(
port
,
port_end
))
logging
.
info
(
"RPCServer:
bind to
%
s:
%
d"
,
host
,
self
.
port
)
self
.
logger
.
info
(
"
bind to
%
s:
%
d"
,
host
,
self
.
port
)
sock
.
listen
(
1
)
self
.
sock
=
sock
self
.
proc
=
multiprocessing
.
Process
(
target
=
_listen_loop
,
args
=
(
self
.
sock
,
self
.
port
,
key
,
tracker_addr
,
load_library
,
self
.
custom_addr
))
self
.
sock
,
self
.
port
,
key
,
tracker_addr
,
load_library
,
self
.
custom_addr
,
silent
))
self
.
proc
.
deamon
=
True
self
.
proc
.
start
()
else
:
self
.
proc
=
multiprocessing
.
Process
(
target
=
_connect_proxy_loop
,
args
=
((
host
,
port
),
key
,
load_library
))
target
=
_connect_proxy_loop
,
args
=
((
host
,
port
),
key
,
load_library
,
silent
))
self
.
proc
.
deamon
=
True
self
.
proc
.
start
()
...
...
python/tvm/contrib/rpc/tracker.py
View file @
d0eb2d3d
...
...
@@ -309,7 +309,6 @@ class TrackerServerHandler(object):
def
_tracker_server
(
listen_sock
,
stop_key
):
handler
=
TrackerServerHandler
(
listen_sock
,
stop_key
)
handler
.
run
()
logging
.
info
(
"Tracker Stop signal received, terminating..."
)
class
Tracker
(
object
):
...
...
@@ -327,11 +326,19 @@ class Tracker(object):
port_end : int, optional
The end TCP port to search
silent: bool, optional
Whether run in silent mode
"""
def
__init__
(
self
,
host
,
port
=
9190
,
port_end
=
9199
):
port_end
=
9199
,
silent
=
False
):
self
.
logger
=
logging
.
getLogger
(
"RPCTracker"
)
if
silent
:
self
.
logger
.
disabled
=
True
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
port
=
None
self
.
stop_key
=
base
.
random_key
(
"tracker"
)
...
...
@@ -347,7 +354,7 @@ class Tracker(object):
raise
sock_err
if
not
self
.
port
:
raise
ValueError
(
"cannot bind to any port in [
%
d,
%
d)"
%
(
port
,
port_end
))
logging
.
info
(
"RPCTracker:
bind to
%
s:
%
d"
,
host
,
self
.
port
)
self
.
logger
.
info
(
"
bind to
%
s:
%
d"
,
host
,
self
.
port
)
sock
.
listen
(
1
)
self
.
proc
=
multiprocessing
.
Process
(
target
=
_tracker_server
,
args
=
(
sock
,
self
.
stop_key
))
...
...
@@ -373,7 +380,7 @@ class Tracker(object):
self
.
_stop_tracker
()
self
.
proc
.
join
(
1
)
if
self
.
proc
.
is_alive
():
logging
.
info
(
"Terminating Tracker Server..."
)
self
.
logger
.
info
(
"Terminating Tracker Server..."
)
self
.
proc
.
terminate
()
self
.
proc
=
None
...
...
python/tvm/exec/rpc_server.py
View file @
d0eb2d3d
...
...
@@ -27,7 +27,8 @@ def main(args):
key
=
args
.
key
,
tracker_addr
=
tracker_addr
,
load_library
=
args
.
load_library
,
custom_addr
=
args
.
custom_addr
)
custom_addr
=
args
.
custom_addr
,
silent
=
args
.
silent
)
server
.
proc
.
join
()
...
...
@@ -51,6 +52,8 @@ if __name__ == "__main__":
and ROCM compilers."
)
parser
.
add_argument
(
'--custom-addr'
,
type
=
str
,
help
=
"Custom IP Address to Report to RPC Tracker"
)
parser
.
add_argument
(
'--silent'
,
action
=
'store_true'
,
help
=
"Whether run in silent mode."
)
parser
.
set_defaults
(
fork
=
True
)
args
=
parser
.
parse_args
()
...
...
@@ -62,6 +65,7 @@ if __name__ == "__main__":
)
multiprocessing
.
set_start_method
(
'spawn'
)
else
:
logging
.
info
(
"If you are running ROCM/Metal,
\
fork with cause compiler internal error. Try to launch with arg ```--no-fork```"
)
if
not
args
.
silent
:
logging
.
info
(
"If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```"
)
main
(
args
)
python/tvm/exec/rpc_tracker.py
View file @
d0eb2d3d
...
...
@@ -11,7 +11,8 @@ from ..contrib.rpc.tracker import Tracker
def
main
(
args
):
"""Main funciton"""
tracker
=
Tracker
(
args
.
host
,
port
=
args
.
port
)
tracker
=
Tracker
(
args
.
host
,
port
=
args
.
port
,
port_end
=
args
.
port_end
,
silent
=
args
.
silent
)
tracker
.
proc
.
join
()
...
...
@@ -21,10 +22,15 @@ if __name__ == "__main__":
help
=
'the hostname of the tracker'
)
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
9190
,
help
=
'The port of the PRC'
)
parser
.
add_argument
(
'--port-end'
,
type
=
int
,
default
=
9199
,
help
=
'The end search port of the PRC'
)
parser
.
add_argument
(
'--no-fork'
,
dest
=
'fork'
,
action
=
'store_false'
,
help
=
"Use spawn mode to avoid fork. This option
\
is able to avoid potential fork problems with Metal, OpenCL
\
and ROCM compilers."
)
parser
.
add_argument
(
'--silent'
,
action
=
'store_true'
,
help
=
"Whether run in silent mode."
)
parser
.
set_defaults
(
fork
=
True
)
args
=
parser
.
parse_args
()
logging
.
basicConfig
(
level
=
logging
.
INFO
)
...
...
@@ -35,6 +41,7 @@ if __name__ == "__main__":
)
multiprocessing
.
set_start_method
(
'spawn'
)
else
:
logging
.
info
(
"If you are running ROCM/Metal,
\
fork with cause compiler internal error. Try to launch with arg ```--no-fork```"
)
if
not
args
.
silent
:
logging
.
info
(
"If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```"
)
main
(
args
)
src/runtime/rpc/rpc_server_env.cc
View file @
d0eb2d3d
...
...
@@ -20,7 +20,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload").
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
std
::
string
file_name
=
RPCGetPath
(
args
[
0
]);
std
::
string
data
=
args
[
1
];
LOG
(
INFO
)
<<
"Upload "
<<
file_name
<<
"... nbytes="
<<
data
.
length
();
SaveBinaryToFile
(
file_name
,
data
);
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment