Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better graceful shutdown for KeyboardInterrupt #19976

Merged
merged 16 commits into from
Jun 16, 2024
4 changes: 4 additions & 0 deletions src/lightning/fabric/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import contextlib
import logging
import os
import signal
import threading
import time
from contextlib import nullcontext
from datetime import timedelta
Expand Down Expand Up @@ -306,8 +307,11 @@ def _init_dist_connection(


def _destroy_dist_connection() -> None:
    """Destroy the default distributed process group, if one was initialized.

    ``SIGINT`` (Ctrl+C) is temporarily ignored while the group is torn down so that
    an impatient user cannot interrupt the shutdown halfway through; the default
    handler is restored afterwards.
    """
    # `signal.signal` raises ValueError when called outside the main thread, so
    # only install/restore the shield when we are actually on the main thread.
    shield = threading.current_thread() is threading.main_thread()
    if shield:
        # Don't allow Ctrl+C to interrupt this handler
        signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        if _distributed_is_initialized():
            torch.distributed.destroy_process_group()
    finally:
        # Restore the default handler even if `destroy_process_group` raised
        if shield:
            signal.signal(signal.SIGINT, signal.SIG_DFL)


def _get_default_process_group_backend_for_device(device: torch.device) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, An
def kill(self, signum: _SIGNUM) -> None:
    """Deliver ``signum`` to every child process that is still running.

    Children that are no longer alive, have no pid, or disappear between the
    liveness check and the signal delivery are skipped silently.
    """
    own_pid = os.getpid()
    running = (p for p in self.procs if p.is_alive() and p.pid is not None)
    for child in running:
        log.debug(f"Process {own_pid} is terminating {child.pid} with {signum}")
        try:
            os.kill(child.pid, signum)
        except ProcessLookupError:
            # the child exited on its own before the signal arrived
            pass

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
@override
def kill(self, signum: _SIGNUM) -> None:
    """Forward ``signum`` to every subprocess spawned by this launcher."""
    own_pid = os.getpid()
    for child in self.procs:
        log.debug(f"Process {own_pid} is terminating {child.pid} with {signum}")
        # `send_signal` is a no-op for subprocesses that have already terminated
        child.send_signal(signum)

Expand Down
19 changes: 13 additions & 6 deletions src/lightning/pytorch/trainer/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import signal
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Type, Union

Expand All @@ -20,10 +21,11 @@
import lightning.pytorch as pl
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.pytorch.callbacks import Checkpoint, EarlyStopping
from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher
from lightning.pytorch.trainer.states import TrainerStatus
from lightning.pytorch.utilities.exceptions import _TunerExitException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn

log = logging.getLogger(__name__)

Expand All @@ -49,12 +51,17 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg
trainer.state.status = TrainerStatus.FINISHED
trainer.state.stage = None

# TODO: Unify both exceptions below, where `KeyboardInterrupt` doesn't re-raise
except KeyboardInterrupt as exception:
rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
# user could press Ctrl+c many times... only shutdown once
if not trainer.interrupted:
_interrupt(trainer, exception)
rank_zero_info("\nDetected KeyboardInterrupt, attempting graceful shutdown ...")
# user could press Ctrl+C many times, disable KeyboardInterrupt for shutdown
signal.signal(signal.SIGINT, signal.SIG_IGN)
_interrupt(trainer, exception)
trainer._teardown()
launcher = trainer.strategy.launcher
if isinstance(launcher, _SubprocessScriptLauncher):
launcher.kill(signal.SIGKILL)
exit(1)

except BaseException as exception:
_interrupt(trainer, exception)
trainer._teardown()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def on_train_start(self) -> None:

with mock.patch(
"lightning.pytorch.callbacks.progress.rich_progress.Progress.stop", autospec=True
) as mock_progress_stop:
) as mock_progress_stop, pytest.raises(SystemExit):
progress_bar = RichProgressBar()
trainer = Trainer(
default_root_dir=tmp_path,
Expand Down
9 changes: 7 additions & 2 deletions tests/tests_pytorch/callbacks/test_lambda_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from functools import partial

import pytest
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.callbacks import Callback, LambdaCallback
from lightning.pytorch.demos.boring_classes import BoringModel
Expand All @@ -23,10 +24,13 @@
def test_lambda_call(tmp_path):
seed_everything(42)

class CustomException(Exception):
pass

class CustomModel(BoringModel):
def on_train_epoch_start(self):
if self.current_epoch > 1:
raise KeyboardInterrupt
raise CustomException("Custom exception to trigger `on_exception` hooks")

checker = set()

Expand Down Expand Up @@ -59,7 +63,8 @@ def call(hook, *_, **__):
limit_predict_batches=1,
callbacks=[LambdaCallback(**hooks_args)],
)
trainer.fit(model, ckpt_path=ckpt_path)
with pytest.raises(CustomException):
trainer.fit(model, ckpt_path=ckpt_path)
trainer.test(model)
trainer.predict(model)

Expand Down
3 changes: 2 additions & 1 deletion tests/tests_pytorch/trainer/test_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,5 +84,6 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):

trainer = Trainer(callbacks=[InterruptCallback()], default_root_dir=tmp_path, **extra_params)

trainer.fit(model)
with pytest.raises(SystemExit):
trainer.fit(model)
assert trainer.interrupted
7 changes: 4 additions & 3 deletions tests/tests_pytorch/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,8 @@ def on_exception(self, trainer, pl_module, exception):
)
assert not trainer.interrupted
assert handle_interrupt_callback.exception is None
trainer.fit(model)
with pytest.raises(SystemExit):
trainer.fit(model)
assert trainer.interrupted
assert isinstance(handle_interrupt_callback.exception, KeyboardInterrupt)
with pytest.raises(MisconfigurationException):
Expand Down Expand Up @@ -2042,7 +2043,7 @@ def on_fit_start(self):

trainer = Trainer(default_root_dir=tmp_path)
with mock.patch("lightning.pytorch.strategies.strategy.Strategy.on_exception") as on_exception_mock, suppress(
Exception
Exception, SystemExit
):
trainer.fit(ExceptionModel())
on_exception_mock.assert_called_once_with(exception)
Expand All @@ -2061,7 +2062,7 @@ def on_fit_start(self):
datamodule.on_exception = Mock()
trainer = Trainer(default_root_dir=tmp_path)

with suppress(Exception):
with suppress(Exception, SystemExit):
trainer.fit(ExceptionModel(), datamodule=datamodule)
datamodule.on_exception.assert_called_once_with(exception)

Expand Down
Loading