Skip to content

Commit

Permalink
Refactor to use context manager
Browse files Browse the repository at this point in the history
Simplifies the logic and should help avoid mistakes in the future.

Co-authored-by: Gregory P. Smith <greg@krypto.org>
  • Loading branch information
cptpcrd and gpshead committed Oct 2, 2022
1 parent 2215bad commit 99c03e9
Showing 1 changed file with 31 additions and 49 deletions.
80 changes: 31 additions & 49 deletions Lib/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,6 +1306,26 @@ def _close_pipe_fds(self,
# Prevent a double close of these handles/fds from __init__ on error.
self._closed_child_pipe_fds = True

@contextlib.contextmanager
def _on_error_fd_closer(self):
"""Helper to ensure file descriptors opened in _get_handles are closed"""
to_close = []
try:
yield to_close
except:
if hasattr(self, '_devnull'):
to_close.append(self._devnull)
del self._devnull
for fd in to_close:
try:
if _mswindows and isinstance(fd, Handle):
fd.Close()
else:
os.close(fd)
except OSError:
pass
raise

if _mswindows:
#
# Windows methods
Expand All @@ -1321,22 +1341,18 @@ def _get_handles(self, stdin, stdout, stderr):
c2pread, c2pwrite = -1, -1
errread, errwrite = -1, -1

stdin_needsclose = False
stdout_needsclose = False
stderr_needsclose = False

try:
with self._on_error_fd_closer() as err_close_fds:
if stdin is None:
p2cread = _winapi.GetStdHandle(_winapi.STD_INPUT_HANDLE)
if p2cread is None:
stdin_needsclose = True
p2cread, _ = _winapi.CreatePipe(None, 0)
p2cread = Handle(p2cread)
_winapi.CloseHandle(_)
err_close_fds.append(p2cread)
elif stdin == PIPE:
stdin_needsclose = True
p2cread, p2cwrite = _winapi.CreatePipe(None, 0)
p2cread, p2cwrite = Handle(p2cread), Handle(p2cwrite)
err_close_fds.extend((p2cread, p2cwrite))
elif stdin == DEVNULL:
p2cread = msvcrt.get_osfhandle(self._get_devnull())
elif isinstance(stdin, int):
Expand All @@ -1349,14 +1365,14 @@ def _get_handles(self, stdin, stdout, stderr):
if stdout is None:
c2pwrite = _winapi.GetStdHandle(_winapi.STD_OUTPUT_HANDLE)
if c2pwrite is None:
stdout_needsclose = True
_, c2pwrite = _winapi.CreatePipe(None, 0)
c2pwrite = Handle(c2pwrite)
_winapi.CloseHandle(_)
err_close_fds.append(c2pwrite)
elif stdout == PIPE:
stdout_needsclose = True
c2pread, c2pwrite = _winapi.CreatePipe(None, 0)
c2pread, c2pwrite = Handle(c2pread), Handle(c2pwrite)
err_close_fds.extend((c2pread, c2pwrite))
elif stdout == DEVNULL:
c2pwrite = msvcrt.get_osfhandle(self._get_devnull())
elif isinstance(stdout, int):
Expand All @@ -1369,14 +1385,14 @@ def _get_handles(self, stdin, stdout, stderr):
if stderr is None:
errwrite = _winapi.GetStdHandle(_winapi.STD_ERROR_HANDLE)
if errwrite is None:
stderr_needsclose = True
_, errwrite = _winapi.CreatePipe(None, 0)
errwrite = Handle(errwrite)
_winapi.CloseHandle(_)
err_close_fds.append(errwrite)
elif stderr == PIPE:
stderr_needsclose = True
errread, errwrite = _winapi.CreatePipe(None, 0)
errread, errwrite = Handle(errread), Handle(errwrite)
err_close_fds.extend((errread, errwrite))
elif stderr == STDOUT:
errwrite = c2pwrite
elif stderr == DEVNULL:
Expand All @@ -1388,27 +1404,6 @@ def _get_handles(self, stdin, stdout, stderr):
errwrite = msvcrt.get_osfhandle(stderr.fileno())
errwrite = self._make_inheritable(errwrite)

except BaseException:
to_close = []
if stdin_needsclose and p2cwrite != -1:
to_close.append(p2cread)
to_close.append(p2cwrite)
if stdout_needsclose and p2cwrite != -1:
to_close.append(c2pread)
to_close.append(c2pwrite)
if stderr_needsclose and errwrite != -1:
to_close.append(errread)
to_close.append(errwrite)
for file in to_close:
if isinstance(file, Handle):
file.Close()
else:
os.close(file)
if hasattr(self, "_devnull"):
os.close(self._devnull)
del self._devnull
raise

return (p2cread, p2cwrite,
c2pread, c2pwrite,
errread, errwrite)
Expand Down Expand Up @@ -1678,13 +1673,14 @@ def _get_handles(self, stdin, stdout, stderr):
c2pread, c2pwrite = -1, -1
errread, errwrite = -1, -1

try:
with self._on_error_fd_closer() as err_close_fds:
if stdin is None:
pass
elif stdin == PIPE:
p2cread, p2cwrite = os.pipe()
if self.pipesize > 0 and hasattr(fcntl, "F_SETPIPE_SZ"):
fcntl.fcntl(p2cwrite, fcntl.F_SETPIPE_SZ, self.pipesize)
err_close_fds.extend((p2cread, p2cwrite))
elif stdin == DEVNULL:
p2cread = self._get_devnull()
elif isinstance(stdin, int):
Expand All @@ -1699,6 +1695,7 @@ def _get_handles(self, stdin, stdout, stderr):
c2pread, c2pwrite = os.pipe()
if self.pipesize > 0 and hasattr(fcntl, "F_SETPIPE_SZ"):
fcntl.fcntl(c2pwrite, fcntl.F_SETPIPE_SZ, self.pipesize)
err_close_fds.extend((c2pread, c2pwrite))
elif stdout == DEVNULL:
c2pwrite = self._get_devnull()
elif isinstance(stdout, int):
Expand All @@ -1713,6 +1710,7 @@ def _get_handles(self, stdin, stdout, stderr):
errread, errwrite = os.pipe()
if self.pipesize > 0 and hasattr(fcntl, "F_SETPIPE_SZ"):
fcntl.fcntl(errwrite, fcntl.F_SETPIPE_SZ, self.pipesize)
err_close_fds.extend((errread, errwrite))
elif stderr == STDOUT:
if c2pwrite != -1:
errwrite = c2pwrite
Expand All @@ -1726,22 +1724,6 @@ def _get_handles(self, stdin, stdout, stderr):
# Assuming file-like object
errwrite = stderr.fileno()

except BaseException:
# Close the file descriptors we opened to avoid leakage
if stdin == PIPE and p2cwrite != -1:
os.close(p2cread)
os.close(p2cwrite)
if stdout == PIPE and c2pwrite != -1:
os.close(c2pread)
os.close(c2pwrite)
if stderr == PIPE and errwrite != -1:
os.close(errread)
os.close(errwrite)
if hasattr(self, "_devnull"):
os.close(self._devnull)
del self._devnull
raise

return (p2cread, p2cwrite,
c2pread, c2pwrite,
errread, errwrite)
Expand Down

0 comments on commit 99c03e9

Please sign in to comment.