Commit 6b5fbdad by Balint Cristian Committed by Tianqi Chen

[RPC] Better handle tempdir if subprocess killed. (#3574)

parent b0481c82
......@@ -30,8 +30,12 @@ class TempDirectory(object):
Automatically removes the directory when it went out of scope.
"""
def __init__(self):
self.temp_dir = tempfile.mkdtemp()
def __init__(self, custom_path=None):
if custom_path:
os.mkdir(custom_path)
self.temp_dir = custom_path
else:
self.temp_dir = tempfile.mkdtemp()
self._rmtree = shutil.rmtree
def remove(self):
......@@ -69,15 +73,20 @@ class TempDirectory(object):
return os.listdir(self.temp_dir)
def tempdir():
def tempdir(custom_path=None):
"""Create temp dir which deletes the contents when exit.
Parameters
----------
custom_path : str, optional
Manually specify the exact temp dir path
Returns
-------
temp : TempDirectory
The temp directory object
"""
return TempDirectory()
return TempDirectory(custom_path)
class FileLock(object):
......
......@@ -251,7 +251,7 @@ def load(path, fmt=""):
_cc.create_shared(path + ".so", path)
path += ".so"
elif path.endswith(".tar"):
tar_temp = _util.tempdir()
tar_temp = _util.tempdir(custom_path=path.replace('.tar', ''))
_tar.untar(path, tar_temp.temp_dir)
files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
_cc.create_shared(path + ".so", files)
......
......@@ -50,9 +50,12 @@ from . base import TrackerCode
logger = logging.getLogger('RPCServer')
def _server_env(load_library):
def _server_env(load_library, work_path=None):
"""Server environment function return temp dir"""
temp = util.tempdir()
if work_path:
temp = work_path
else:
temp = util.tempdir()
# pylint: disable=unused-variable
@register_func("tvm.rpc.server.workpath")
......@@ -76,16 +79,15 @@ def _server_env(load_library):
temp.libs = libs
return temp
def _serve_loop(sock, addr, load_library):
def _serve_loop(sock, addr, load_library, work_path=None):
"""Server loop"""
sockfd = sock.fileno()
temp = _server_env(load_library)
temp = _server_env(load_library, work_path)
base._ServerLoop(sockfd)
temp.remove()
if not work_path:
temp.remove()
logger.info("Finish serving %s", addr)
def _parse_server_opt(opts):
# parse client options
ret = {}
......@@ -196,9 +198,10 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
raise exc
# step 3: serving
work_path = util.tempdir()
logger.info("connection from %s", addr)
server_proc = multiprocessing.Process(target=_serve_loop,
args=(conn, addr, load_library))
args=(conn, addr, load_library, work_path))
server_proc.deamon = True
server_proc.start()
# close from our side.
......@@ -208,6 +211,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
if server_proc.is_alive():
logger.info("Timeout in RPC session, kill..")
server_proc.terminate()
work_path.remove()
def _connect_proxy_loop(addr, key, load_library):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment