Skip to content
Projects
Groups
Snippets
Help
This project
Loading...
Sign in / Register
Toggle navigation
T
tic
Overview
Overview
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
wenyuanbo
tic
Commits
d0eb2d3d
Commit
d0eb2d3d
authored
Jun 13, 2018
by
Lianmin Zheng
Committed by
Tianqi Chen
Jun 12, 2018
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Add silent mode to rpc server and rpc tracker (#1268)
parent
558cf098
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
91 additions
and
38 deletions
+91
-38
python/tvm/contrib/rpc/base.py
+5
-1
python/tvm/contrib/rpc/proxy.py
+1
-1
python/tvm/contrib/rpc/server.py
+57
-25
python/tvm/contrib/rpc/tracker.py
+11
-4
python/tvm/exec/rpc_server.py
+7
-3
python/tvm/exec/rpc_tracker.py
+10
-3
src/runtime/rpc/rpc_server_env.cc
+0
-1
No files found.
python/tvm/contrib/rpc/base.py
View file @
d0eb2d3d
...
@@ -120,7 +120,7 @@ def random_key(prefix, cmap=None):
...
@@ -120,7 +120,7 @@ def random_key(prefix, cmap=None):
return
prefix
+
str
(
random
.
random
())
return
prefix
+
str
(
random
.
random
())
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
):
def
connect_with_retry
(
addr
,
timeout
=
60
,
retry_period
=
5
,
silent
=
False
):
"""Connect to a TPC address with retry
"""Connect to a TPC address with retry
This function is only reliable to short period of server restart.
This function is only reliable to short period of server restart.
...
@@ -135,6 +135,9 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
...
@@ -135,6 +135,9 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
retry_period : float
retry_period : float
Number of seconds before we retry again.
Number of seconds before we retry again.
silent: bool
whether run in silent mode
"""
"""
tstart
=
time
.
time
()
tstart
=
time
.
time
()
while
True
:
while
True
:
...
@@ -149,6 +152,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
...
@@ -149,6 +152,7 @@ def connect_with_retry(addr, timeout=60, retry_period=5):
if
period
>
timeout
:
if
period
>
timeout
:
raise
RuntimeError
(
raise
RuntimeError
(
"Failed to connect to server
%
s"
%
str
(
addr
))
"Failed to connect to server
%
s"
%
str
(
addr
))
if
not
silent
:
logging
.
info
(
"Cannot connect to tracker
%
s, retry in
%
g secs..."
,
logging
.
info
(
"Cannot connect to tracker
%
s, retry in
%
g secs..."
,
str
(
addr
),
retry_period
)
str
(
addr
),
retry_period
)
time
.
sleep
(
retry_period
)
time
.
sleep
(
retry_period
)
...
...
python/tvm/contrib/rpc/proxy.py
View file @
d0eb2d3d
...
@@ -536,7 +536,7 @@ def websocket_proxy_server(url, key=""):
...
@@ -536,7 +536,7 @@ def websocket_proxy_server(url, key=""):
def
_connect
(
key
):
def
_connect
(
key
):
conn
=
yield
websocket
.
websocket_connect
(
url
)
conn
=
yield
websocket
.
websocket_connect
(
url
)
on_message
=
create_on_message
(
conn
)
on_message
=
create_on_message
(
conn
)
temp
=
_server_env
(
None
)
temp
=
_server_env
(
None
,
None
)
# Start connecton
# Start connecton
conn
.
write_message
(
struct
.
pack
(
'<i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
conn
.
write_message
(
struct
.
pack
(
'<i'
,
base
.
RPC_MAGIC
),
binary
=
True
)
key
=
"server:"
+
key
key
=
"server:"
+
key
...
...
python/tvm/contrib/rpc/server.py
View file @
d0eb2d3d
...
@@ -19,6 +19,7 @@ import logging
...
@@ -19,6 +19,7 @@ import logging
import
multiprocessing
import
multiprocessing
import
subprocess
import
subprocess
import
time
import
time
import
sys
from
..._ffi.function
import
register_func
from
..._ffi.function
import
register_func
from
..._ffi.base
import
py_str
from
..._ffi.base
import
py_str
...
@@ -28,9 +29,12 @@ from .. import util
...
@@ -28,9 +29,12 @@ from .. import util
from
.
import
base
from
.
import
base
from
.
base
import
TrackerCode
from
.
base
import
TrackerCode
def
_server_env
(
load_library
):
def
_server_env
(
load_library
,
logger
):
"""Server environment function return temp dir"""
"""Server environment function return temp dir"""
temp
=
util
.
tempdir
()
temp
=
util
.
tempdir
()
if
logger
is
None
:
logger
=
logging
.
getLogger
()
# pylint: disable=unused-variable
# pylint: disable=unused-variable
@register_func
(
"tvm.contrib.rpc.server.workpath"
)
@register_func
(
"tvm.contrib.rpc.server.workpath"
)
def
get_workpath
(
path
):
def
get_workpath
(
path
):
...
@@ -41,7 +45,7 @@ def _server_env(load_library):
...
@@ -41,7 +45,7 @@ def _server_env(load_library):
"""Load module from remote side."""
"""Load module from remote side."""
path
=
temp
.
relpath
(
file_name
)
path
=
temp
.
relpath
(
file_name
)
m
=
_load_module
(
path
)
m
=
_load_module
(
path
)
logg
ing
.
info
(
"load_module
%
s"
,
path
)
logg
er
.
info
(
"load_module
%
s"
,
path
)
return
m
return
m
libs
=
[]
libs
=
[]
...
@@ -49,18 +53,21 @@ def _server_env(load_library):
...
@@ -49,18 +53,21 @@ def _server_env(load_library):
for
file_name
in
load_library
:
for
file_name
in
load_library
:
file_name
=
find_lib_path
(
file_name
)[
0
]
file_name
=
find_lib_path
(
file_name
)[
0
]
libs
.
append
(
ctypes
.
CDLL
(
file_name
,
ctypes
.
RTLD_GLOBAL
))
libs
.
append
(
ctypes
.
CDLL
(
file_name
,
ctypes
.
RTLD_GLOBAL
))
logg
ing
.
info
(
"Load additional library
%
s"
,
file_name
)
logg
er
.
info
(
"Load additional library
%
s"
,
file_name
)
temp
.
libs
=
libs
temp
.
libs
=
libs
return
temp
return
temp
def
_serve_loop
(
sock
,
addr
,
load_library
):
def
_serve_loop
(
sock
,
addr
,
load_library
,
silent
):
"""Server loop"""
"""Server loop"""
logger
=
logging
.
getLogger
(
"RPCServer"
)
if
silent
:
logger
.
disabled
=
True
sockfd
=
sock
.
fileno
()
sockfd
=
sock
.
fileno
()
temp
=
_server_env
(
load_library
)
temp
=
_server_env
(
load_library
,
logger
)
base
.
_ServerLoop
(
sockfd
)
base
.
_ServerLoop
(
sockfd
)
temp
.
remove
()
temp
.
remove
()
logg
ing
.
info
(
"Finish serving
%
s"
,
addr
)
logg
er
.
info
(
"Finish serving
%
s"
,
addr
)
def
_parse_server_opt
(
opts
):
def
_parse_server_opt
(
opts
):
...
@@ -71,8 +78,12 @@ def _parse_server_opt(opts):
...
@@ -71,8 +78,12 @@ def _parse_server_opt(opts):
ret
[
"timeout"
]
=
float
(
kv
[
9
:])
ret
[
"timeout"
]
=
float
(
kv
[
9
:])
return
ret
return
ret
def
_listen_loop
(
sock
,
port
,
rpc_key
,
tracker_addr
,
load_library
,
custom_addr
):
def
_listen_loop
(
sock
,
port
,
rpc_key
,
tracker_addr
,
load_library
,
custom_addr
,
silent
):
"""Lisenting loop of the server master."""
"""Listening loop of the server master."""
logger
=
logging
.
getLogger
(
"RPCServer"
)
if
silent
:
logger
.
disabled
=
True
def
_accept_conn
(
listen_sock
,
tracker_conn
,
ping_period
=
2
):
def
_accept_conn
(
listen_sock
,
tracker_conn
,
ping_period
=
2
):
"""Accept connection from the other places.
"""Accept connection from the other places.
...
@@ -115,7 +126,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
...
@@ -115,7 +126,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
unmatch_period_count
=
0
unmatch_period_count
=
0
# regenerate match key if key is acquired but not used for a while
# regenerate match key if key is acquired but not used for a while
if
unmatch_period_count
*
ping_period
>
unmatch_timeout
+
ping_period
:
if
unmatch_period_count
*
ping_period
>
unmatch_timeout
+
ping_period
:
logg
ing
.
info
(
"RPCServer:
no incoming connections, regenerate key ..."
)
logg
er
.
info
(
"
no incoming connections, regenerate key ..."
)
matchkey
=
base
.
random_key
(
rpc_key
+
":"
,
old_keyset
)
matchkey
=
base
.
random_key
(
rpc_key
+
":"
,
old_keyset
)
base
.
sendjson
(
tracker_conn
,
base
.
sendjson
(
tracker_conn
,
[
TrackerCode
.
PUT
,
rpc_key
,
(
port
,
matchkey
),
[
TrackerCode
.
PUT
,
rpc_key
,
(
port
,
matchkey
),
...
@@ -136,7 +147,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
...
@@ -136,7 +147,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
if
arr
[
0
]
!=
expect_header
:
if
arr
[
0
]
!=
expect_header
:
conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_CODE_MISMATCH
))
conn
.
close
()
conn
.
close
()
logg
ing
.
info
(
"RPCServer:
mismatch key from
%
s"
,
addr
)
logg
er
.
info
(
"
mismatch key from
%
s"
,
addr
)
continue
continue
else
:
else
:
conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_CODE_SUCCESS
))
conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_CODE_SUCCESS
))
...
@@ -150,7 +161,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
...
@@ -150,7 +161,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
try
:
try
:
# step 1: setup tracker and report to tracker
# step 1: setup tracker and report to tracker
if
tracker_addr
and
tracker_conn
is
None
:
if
tracker_addr
and
tracker_conn
is
None
:
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
)
tracker_conn
=
base
.
connect_with_retry
(
tracker_addr
,
silent
=
silent
)
tracker_conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_TRACKER_MAGIC
))
tracker_conn
.
sendall
(
struct
.
pack
(
"<i"
,
base
.
RPC_TRACKER_MAGIC
))
magic
=
struct
.
unpack
(
"<i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
magic
=
struct
.
unpack
(
"<i"
,
base
.
recvall
(
tracker_conn
,
4
))[
0
]
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
if
magic
!=
base
.
RPC_TRACKER_MAGIC
:
...
@@ -169,10 +180,16 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
...
@@ -169,10 +180,16 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
tracker_conn
.
close
()
tracker_conn
.
close
()
tracker_conn
=
None
tracker_conn
=
None
continue
continue
except
RuntimeError
as
exc
:
if
silent
:
return
else
:
raise
exc
# step 3: serving
# step 3: serving
logging
.
info
(
"RPCServer: connection from
%
s"
,
addr
)
logger
.
info
(
"connection from
%
s"
,
addr
)
server_proc
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
conn
,
addr
,
load_library
))
server_proc
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
conn
,
addr
,
load_library
,
silent
))
server_proc
.
deamon
=
True
server_proc
.
deamon
=
True
server_proc
.
start
()
server_proc
.
start
()
# close from our side.
# close from our side.
...
@@ -180,11 +197,14 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
...
@@ -180,11 +197,14 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
# wait until server process finish or timeout
# wait until server process finish or timeout
server_proc
.
join
(
opts
.
get
(
"timeout"
,
None
))
server_proc
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
server_proc
.
is_alive
():
if
server_proc
.
is_alive
():
logg
ing
.
info
(
"RPCServer:
Timeout in RPC session, kill.."
)
logg
er
.
info
(
"
Timeout in RPC session, kill.."
)
server_proc
.
terminate
()
server_proc
.
terminate
()
def
_connect_proxy_loop
(
addr
,
key
,
load_library
):
def
_connect_proxy_loop
(
addr
,
key
,
load_library
,
silent
):
logger
=
logging
.
getLogger
(
"RPCProxy"
)
if
silent
:
logger
.
disabled
=
True
key
=
"server:"
+
key
key
=
"server:"
+
key
retry_count
=
0
retry_count
=
0
max_retry
=
5
max_retry
=
5
...
@@ -200,26 +220,26 @@ def _connect_proxy_loop(addr, key, load_library):
...
@@ -200,26 +220,26 @@ def _connect_proxy_loop(addr, key, load_library):
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
if
magic
==
base
.
RPC_CODE_DUPLICATE
:
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
raise
RuntimeError
(
"key:
%
s has already been used in proxy"
%
key
)
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
elif
magic
==
base
.
RPC_CODE_MISMATCH
:
logg
ing
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
logg
er
.
info
(
"RPCProxy do not have matching client key
%
s"
,
key
)
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
elif
magic
!=
base
.
RPC_CODE_SUCCESS
:
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
raise
RuntimeError
(
"
%
s is not RPC Proxy"
%
str
(
addr
))
keylen
=
struct
.
unpack
(
"<i"
,
base
.
recvall
(
sock
,
4
))[
0
]
keylen
=
struct
.
unpack
(
"<i"
,
base
.
recvall
(
sock
,
4
))[
0
]
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
remote_key
=
py_str
(
base
.
recvall
(
sock
,
keylen
))
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
opts
=
_parse_server_opt
(
remote_key
.
split
()[
1
:])
logg
ing
.
info
(
"RPCProxy
connected to
%
s"
,
str
(
addr
))
logg
er
.
info
(
"
connected to
%
s"
,
str
(
addr
))
process
=
multiprocessing
.
Process
(
process
=
multiprocessing
.
Process
(
target
=
_serve_loop
,
args
=
(
sock
,
addr
,
load_library
))
target
=
_serve_loop
,
args
=
(
sock
,
addr
,
load_library
,
silent
))
process
.
deamon
=
True
process
.
deamon
=
True
process
.
start
()
process
.
start
()
sock
.
close
()
sock
.
close
()
process
.
join
(
opts
.
get
(
"timeout"
,
None
))
process
.
join
(
opts
.
get
(
"timeout"
,
None
))
if
process
.
is_alive
():
if
process
.
is_alive
():
logg
ing
.
info
(
"RPCProxyServer:
Timeout in RPC session, kill.."
)
logg
er
.
info
(
"
Timeout in RPC session, kill.."
)
process
.
terminate
()
process
.
terminate
()
retry_count
=
0
retry_count
=
0
except
(
socket
.
error
,
IOError
)
as
err
:
except
(
socket
.
error
,
IOError
)
as
err
:
retry_count
+=
1
retry_count
+=
1
logg
ing
.
info
(
"Error encountered
%
s, retry in
%
g sec"
,
str
(
err
),
retry_period
)
logg
er
.
info
(
"Error encountered
%
s, retry in
%
g sec"
,
str
(
err
),
retry_period
)
if
retry_count
>
max_retry
:
if
retry_count
>
max_retry
:
raise
RuntimeError
(
"Maximum retry error: last error:
%
s"
%
str
(
err
))
raise
RuntimeError
(
"Maximum retry error: last error:
%
s"
%
str
(
err
))
time
.
sleep
(
retry_period
)
time
.
sleep
(
retry_period
)
...
@@ -264,6 +284,9 @@ class Server(object):
...
@@ -264,6 +284,9 @@ class Server(object):
This is recommended to switch on if we want to do local RPC demonstration
This is recommended to switch on if we want to do local RPC demonstration
for GPU devices to avoid fork safety issues.
for GPU devices to avoid fork safety issues.
silent: bool, optional
Whether run this server in silent mode.
key : str, optional
key : str, optional
The key used to identify the server in Proxy connection.
The key used to identify the server in Proxy connection.
...
@@ -276,6 +299,7 @@ class Server(object):
...
@@ -276,6 +299,7 @@ class Server(object):
port_end
=
9199
,
port_end
=
9199
,
is_proxy
=
False
,
is_proxy
=
False
,
use_popen
=
False
,
use_popen
=
False
,
silent
=
False
,
tracker_addr
=
None
,
tracker_addr
=
None
,
key
=
""
,
key
=
""
,
load_library
=
None
,
load_library
=
None
,
...
@@ -290,8 +314,12 @@ class Server(object):
...
@@ -290,8 +314,12 @@ class Server(object):
self
.
libs
=
[]
self
.
libs
=
[]
self
.
custom_addr
=
custom_addr
self
.
custom_addr
=
custom_addr
self
.
logger
=
logging
.
getLogger
(
"RPCServer"
)
if
silent
:
self
.
logger
.
disabled
=
True
if
use_popen
:
if
use_popen
:
cmd
=
[
"python"
,
cmd
=
[
sys
.
executable
,
"-m"
,
"tvm.exec.rpc_server"
,
"-m"
,
"tvm.exec.rpc_server"
,
"--host=
%
s"
%
host
,
"--host=
%
s"
%
host
,
"--port=
%
s"
%
port
]
"--port=
%
s"
%
port
]
...
@@ -303,11 +331,14 @@ class Server(object):
...
@@ -303,11 +331,14 @@ class Server(object):
cmd
+=
[
"--load-library"
,
load_library
]
cmd
+=
[
"--load-library"
,
load_library
]
if
custom_addr
:
if
custom_addr
:
cmd
+=
[
"--custom-addr"
,
custom_addr
]
cmd
+=
[
"--custom-addr"
,
custom_addr
]
if
silent
:
cmd
+=
[
"--silent"
]
self
.
proc
=
multiprocessing
.
Process
(
self
.
proc
=
multiprocessing
.
Process
(
target
=
subprocess
.
check_call
,
args
=
(
cmd
,))
target
=
subprocess
.
check_call
,
args
=
(
cmd
,))
self
.
proc
.
deamon
=
True
self
.
proc
.
deamon
=
True
self
.
proc
.
start
()
self
.
proc
.
start
()
time
.
sleep
(
1
)
time
.
sleep
(
0.5
)
elif
not
is_proxy
:
elif
not
is_proxy
:
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
port
=
None
self
.
port
=
None
...
@@ -323,17 +354,18 @@ class Server(object):
...
@@ -323,17 +354,18 @@ class Server(object):
raise
sock_err
raise
sock_err
if
not
self
.
port
:
if
not
self
.
port
:
raise
ValueError
(
"cannot bind to any port in [
%
d,
%
d)"
%
(
port
,
port_end
))
raise
ValueError
(
"cannot bind to any port in [
%
d,
%
d)"
%
(
port
,
port_end
))
logging
.
info
(
"RPCServer:
bind to
%
s:
%
d"
,
host
,
self
.
port
)
self
.
logger
.
info
(
"
bind to
%
s:
%
d"
,
host
,
self
.
port
)
sock
.
listen
(
1
)
sock
.
listen
(
1
)
self
.
sock
=
sock
self
.
sock
=
sock
self
.
proc
=
multiprocessing
.
Process
(
self
.
proc
=
multiprocessing
.
Process
(
target
=
_listen_loop
,
args
=
(
target
=
_listen_loop
,
args
=
(
self
.
sock
,
self
.
port
,
key
,
tracker_addr
,
load_library
,
self
.
custom_addr
))
self
.
sock
,
self
.
port
,
key
,
tracker_addr
,
load_library
,
self
.
custom_addr
,
silent
))
self
.
proc
.
deamon
=
True
self
.
proc
.
deamon
=
True
self
.
proc
.
start
()
self
.
proc
.
start
()
else
:
else
:
self
.
proc
=
multiprocessing
.
Process
(
self
.
proc
=
multiprocessing
.
Process
(
target
=
_connect_proxy_loop
,
args
=
((
host
,
port
),
key
,
load_library
))
target
=
_connect_proxy_loop
,
args
=
((
host
,
port
),
key
,
load_library
,
silent
))
self
.
proc
.
deamon
=
True
self
.
proc
.
deamon
=
True
self
.
proc
.
start
()
self
.
proc
.
start
()
...
...
python/tvm/contrib/rpc/tracker.py
View file @
d0eb2d3d
...
@@ -309,7 +309,6 @@ class TrackerServerHandler(object):
...
@@ -309,7 +309,6 @@ class TrackerServerHandler(object):
def
_tracker_server
(
listen_sock
,
stop_key
):
def
_tracker_server
(
listen_sock
,
stop_key
):
handler
=
TrackerServerHandler
(
listen_sock
,
stop_key
)
handler
=
TrackerServerHandler
(
listen_sock
,
stop_key
)
handler
.
run
()
handler
.
run
()
logging
.
info
(
"Tracker Stop signal received, terminating..."
)
class
Tracker
(
object
):
class
Tracker
(
object
):
...
@@ -327,11 +326,19 @@ class Tracker(object):
...
@@ -327,11 +326,19 @@ class Tracker(object):
port_end : int, optional
port_end : int, optional
The end TCP port to search
The end TCP port to search
silent: bool, optional
Whether run in silent mode
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
host
,
host
,
port
=
9190
,
port
=
9190
,
port_end
=
9199
):
port_end
=
9199
,
silent
=
False
):
self
.
logger
=
logging
.
getLogger
(
"RPCTracker"
)
if
silent
:
self
.
logger
.
disabled
=
True
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
sock
=
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
self
.
port
=
None
self
.
port
=
None
self
.
stop_key
=
base
.
random_key
(
"tracker"
)
self
.
stop_key
=
base
.
random_key
(
"tracker"
)
...
@@ -347,7 +354,7 @@ class Tracker(object):
...
@@ -347,7 +354,7 @@ class Tracker(object):
raise
sock_err
raise
sock_err
if
not
self
.
port
:
if
not
self
.
port
:
raise
ValueError
(
"cannot bind to any port in [
%
d,
%
d)"
%
(
port
,
port_end
))
raise
ValueError
(
"cannot bind to any port in [
%
d,
%
d)"
%
(
port
,
port_end
))
logging
.
info
(
"RPCTracker:
bind to
%
s:
%
d"
,
host
,
self
.
port
)
self
.
logger
.
info
(
"
bind to
%
s:
%
d"
,
host
,
self
.
port
)
sock
.
listen
(
1
)
sock
.
listen
(
1
)
self
.
proc
=
multiprocessing
.
Process
(
self
.
proc
=
multiprocessing
.
Process
(
target
=
_tracker_server
,
args
=
(
sock
,
self
.
stop_key
))
target
=
_tracker_server
,
args
=
(
sock
,
self
.
stop_key
))
...
@@ -373,7 +380,7 @@ class Tracker(object):
...
@@ -373,7 +380,7 @@ class Tracker(object):
self
.
_stop_tracker
()
self
.
_stop_tracker
()
self
.
proc
.
join
(
1
)
self
.
proc
.
join
(
1
)
if
self
.
proc
.
is_alive
():
if
self
.
proc
.
is_alive
():
logging
.
info
(
"Terminating Tracker Server..."
)
self
.
logger
.
info
(
"Terminating Tracker Server..."
)
self
.
proc
.
terminate
()
self
.
proc
.
terminate
()
self
.
proc
=
None
self
.
proc
=
None
...
...
python/tvm/exec/rpc_server.py
View file @
d0eb2d3d
...
@@ -27,7 +27,8 @@ def main(args):
...
@@ -27,7 +27,8 @@ def main(args):
key
=
args
.
key
,
key
=
args
.
key
,
tracker_addr
=
tracker_addr
,
tracker_addr
=
tracker_addr
,
load_library
=
args
.
load_library
,
load_library
=
args
.
load_library
,
custom_addr
=
args
.
custom_addr
)
custom_addr
=
args
.
custom_addr
,
silent
=
args
.
silent
)
server
.
proc
.
join
()
server
.
proc
.
join
()
...
@@ -51,6 +52,8 @@ if __name__ == "__main__":
...
@@ -51,6 +52,8 @@ if __name__ == "__main__":
and ROCM compilers."
)
and ROCM compilers."
)
parser
.
add_argument
(
'--custom-addr'
,
type
=
str
,
parser
.
add_argument
(
'--custom-addr'
,
type
=
str
,
help
=
"Custom IP Address to Report to RPC Tracker"
)
help
=
"Custom IP Address to Report to RPC Tracker"
)
parser
.
add_argument
(
'--silent'
,
action
=
'store_true'
,
help
=
"Whether run in silent mode."
)
parser
.
set_defaults
(
fork
=
True
)
parser
.
set_defaults
(
fork
=
True
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -62,6 +65,7 @@ if __name__ == "__main__":
...
@@ -62,6 +65,7 @@ if __name__ == "__main__":
)
)
multiprocessing
.
set_start_method
(
'spawn'
)
multiprocessing
.
set_start_method
(
'spawn'
)
else
:
else
:
logging
.
info
(
"If you are running ROCM/Metal,
\
if
not
args
.
silent
:
fork with cause compiler internal error. Try to launch with arg ```--no-fork```"
)
logging
.
info
(
"If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```"
)
main
(
args
)
main
(
args
)
python/tvm/exec/rpc_tracker.py
View file @
d0eb2d3d
...
@@ -11,7 +11,8 @@ from ..contrib.rpc.tracker import Tracker
...
@@ -11,7 +11,8 @@ from ..contrib.rpc.tracker import Tracker
def
main
(
args
):
def
main
(
args
):
"""Main funciton"""
"""Main funciton"""
tracker
=
Tracker
(
args
.
host
,
port
=
args
.
port
)
tracker
=
Tracker
(
args
.
host
,
port
=
args
.
port
,
port_end
=
args
.
port_end
,
silent
=
args
.
silent
)
tracker
.
proc
.
join
()
tracker
.
proc
.
join
()
...
@@ -21,10 +22,15 @@ if __name__ == "__main__":
...
@@ -21,10 +22,15 @@ if __name__ == "__main__":
help
=
'the hostname of the tracker'
)
help
=
'the hostname of the tracker'
)
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
9190
,
parser
.
add_argument
(
'--port'
,
type
=
int
,
default
=
9190
,
help
=
'The port of the PRC'
)
help
=
'The port of the PRC'
)
parser
.
add_argument
(
'--port-end'
,
type
=
int
,
default
=
9199
,
help
=
'The end search port of the PRC'
)
parser
.
add_argument
(
'--no-fork'
,
dest
=
'fork'
,
action
=
'store_false'
,
parser
.
add_argument
(
'--no-fork'
,
dest
=
'fork'
,
action
=
'store_false'
,
help
=
"Use spawn mode to avoid fork. This option
\
help
=
"Use spawn mode to avoid fork. This option
\
is able to avoid potential fork problems with Metal, OpenCL
\
is able to avoid potential fork problems with Metal, OpenCL
\
and ROCM compilers."
)
and ROCM compilers."
)
parser
.
add_argument
(
'--silent'
,
action
=
'store_true'
,
help
=
"Whether run in silent mode."
)
parser
.
set_defaults
(
fork
=
True
)
parser
.
set_defaults
(
fork
=
True
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
logging
.
basicConfig
(
level
=
logging
.
INFO
)
logging
.
basicConfig
(
level
=
logging
.
INFO
)
...
@@ -35,6 +41,7 @@ if __name__ == "__main__":
...
@@ -35,6 +41,7 @@ if __name__ == "__main__":
)
)
multiprocessing
.
set_start_method
(
'spawn'
)
multiprocessing
.
set_start_method
(
'spawn'
)
else
:
else
:
logging
.
info
(
"If you are running ROCM/Metal,
\
if
not
args
.
silent
:
fork with cause compiler internal error. Try to launch with arg ```--no-fork```"
)
logging
.
info
(
"If you are running ROCM/Metal, fork will cause "
"compiler internal error. Try to launch with arg ```--no-fork```"
)
main
(
args
)
main
(
args
)
src/runtime/rpc/rpc_server_env.cc
View file @
d0eb2d3d
...
@@ -20,7 +20,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload").
...
@@ -20,7 +20,6 @@ TVM_REGISTER_GLOBAL("tvm.contrib.rpc.server.upload").
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
set_body
([](
TVMArgs
args
,
TVMRetValue
*
rv
)
{
std
::
string
file_name
=
RPCGetPath
(
args
[
0
]);
std
::
string
file_name
=
RPCGetPath
(
args
[
0
]);
std
::
string
data
=
args
[
1
];
std
::
string
data
=
args
[
1
];
LOG
(
INFO
)
<<
"Upload "
<<
file_name
<<
"... nbytes="
<<
data
.
length
();
SaveBinaryToFile
(
file_name
,
data
);
SaveBinaryToFile
(
file_name
,
data
);
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment