Commit 6b5fbdad by Balint Cristian Committed by Tianqi Chen

[RPC] Better handle tempdir if subprocess killed. (#3574)

parent b0481c82
...@@ -30,7 +30,11 @@ class TempDirectory(object): ...@@ -30,7 +30,11 @@ class TempDirectory(object):
Automatically removes the directory when it went out of scope. Automatically removes the directory when it went out of scope.
""" """
def __init__(self): def __init__(self, custom_path=None):
if custom_path:
os.mkdir(custom_path)
self.temp_dir = custom_path
else:
self.temp_dir = tempfile.mkdtemp() self.temp_dir = tempfile.mkdtemp()
self._rmtree = shutil.rmtree self._rmtree = shutil.rmtree
...@@ -69,15 +73,20 @@ class TempDirectory(object): ...@@ -69,15 +73,20 @@ class TempDirectory(object):
return os.listdir(self.temp_dir) return os.listdir(self.temp_dir)
def tempdir(): def tempdir(custom_path=None):
"""Create temp dir which deletes the contents when exit. """Create temp dir which deletes the contents when exit.
Parameters
----------
custom_path : str, optional
Manually specify the exact temp dir path
Returns Returns
------- -------
temp : TempDirectory temp : TempDirectory
The temp directory object The temp directory object
""" """
return TempDirectory() return TempDirectory(custom_path)
class FileLock(object): class FileLock(object):
......
...@@ -251,7 +251,7 @@ def load(path, fmt=""): ...@@ -251,7 +251,7 @@ def load(path, fmt=""):
_cc.create_shared(path + ".so", path) _cc.create_shared(path + ".so", path)
path += ".so" path += ".so"
elif path.endswith(".tar"): elif path.endswith(".tar"):
tar_temp = _util.tempdir() tar_temp = _util.tempdir(custom_path=path.replace('.tar', ''))
_tar.untar(path, tar_temp.temp_dir) _tar.untar(path, tar_temp.temp_dir)
files = [tar_temp.relpath(x) for x in tar_temp.listdir()] files = [tar_temp.relpath(x) for x in tar_temp.listdir()]
_cc.create_shared(path + ".so", files) _cc.create_shared(path + ".so", files)
......
...@@ -50,8 +50,11 @@ from . base import TrackerCode ...@@ -50,8 +50,11 @@ from . base import TrackerCode
logger = logging.getLogger('RPCServer') logger = logging.getLogger('RPCServer')
def _server_env(load_library): def _server_env(load_library, work_path=None):
"""Server environment function return temp dir""" """Server environment function return temp dir"""
if work_path:
temp = work_path
else:
temp = util.tempdir() temp = util.tempdir()
# pylint: disable=unused-variable # pylint: disable=unused-variable
...@@ -76,16 +79,15 @@ def _server_env(load_library): ...@@ -76,16 +79,15 @@ def _server_env(load_library):
temp.libs = libs temp.libs = libs
return temp return temp
def _serve_loop(sock, addr, load_library, work_path=None):
def _serve_loop(sock, addr, load_library):
"""Server loop""" """Server loop"""
sockfd = sock.fileno() sockfd = sock.fileno()
temp = _server_env(load_library) temp = _server_env(load_library, work_path)
base._ServerLoop(sockfd) base._ServerLoop(sockfd)
if not work_path:
temp.remove() temp.remove()
logger.info("Finish serving %s", addr) logger.info("Finish serving %s", addr)
def _parse_server_opt(opts): def _parse_server_opt(opts):
# parse client options # parse client options
ret = {} ret = {}
...@@ -196,9 +198,10 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): ...@@ -196,9 +198,10 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
raise exc raise exc
# step 3: serving # step 3: serving
work_path = util.tempdir()
logger.info("connection from %s", addr) logger.info("connection from %s", addr)
server_proc = multiprocessing.Process(target=_serve_loop, server_proc = multiprocessing.Process(target=_serve_loop,
args=(conn, addr, load_library)) args=(conn, addr, load_library, work_path))
server_proc.deamon = True server_proc.deamon = True
server_proc.start() server_proc.start()
# close from our side. # close from our side.
...@@ -208,6 +211,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr): ...@@ -208,6 +211,7 @@ def _listen_loop(sock, port, rpc_key, tracker_addr, load_library, custom_addr):
if server_proc.is_alive(): if server_proc.is_alive():
logger.info("Timeout in RPC session, kill..") logger.info("Timeout in RPC session, kill..")
server_proc.terminate() server_proc.terminate()
work_path.remove()
def _connect_proxy_loop(addr, key, load_library): def _connect_proxy_loop(addr, key, load_library):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment