Merged

26 commits
6a1bbf1 Update signal_connector.py (KAVYANSHTYAGI, May 13, 2025)
2761ad8 Update training_epoch_loop.py (KAVYANSHTYAGI, May 13, 2025)
93c3e69 Create test_ddp_sigterm_handling.py (KAVYANSHTYAGI, May 13, 2025)
b7cef51 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 13, 2025)
2fc5178 Update training_epoch_loop.py (KAVYANSHTYAGI, May 15, 2025)
c8c9523 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 15, 2025)
ebbe682 Update test_ddp_sigterm_handling.py (KAVYANSHTYAGI, May 15, 2025)
f327aa7 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 15, 2025)
f50b3a9 Update training_epoch_loop.py (KAVYANSHTYAGI, May 15, 2025)
49e2fab [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 15, 2025)
5600bff Update training_epoch_loop.py (KAVYANSHTYAGI, May 16, 2025)
873a792 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 16, 2025)
4a275da Update test_ddp_sigterm_handling.py (KAVYANSHTYAGI, May 16, 2025)
b792073 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 16, 2025)
63c922d linter (Borda, May 16, 2025)
22dc0ab Merge branch 'Lightning-AI:master' into sigterm-deadlock (KAVYANSHTYAGI, May 17, 2025)
71189de Update training_epoch_loop.py (KAVYANSHTYAGI, May 17, 2025)
ec210cb [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 17, 2025)
cb25184 Merge branch 'Lightning-AI:master' into sigterm-deadlock (KAVYANSHTYAGI, May 22, 2025)
81b3d24 Update training_epoch_loop.py (KAVYANSHTYAGI, May 22, 2025)
67a3b57 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], May 22, 2025)
79b39db Merge branch 'master' into sigterm-deadlock (KAVYANSHTYAGI, May 27, 2025)
d1ab68f update + chlog (Borda, May 28, 2025)
7293e6e Apply suggestions from code review (Borda, May 28, 2025)
857637a linting (Borda, May 28, 2025)
36e9ecf type (Borda, May 28, 2025)
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -28,6 +28,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed logger_connector has edge case where step can be a float ([#20692](https://github.com/Lightning-AI/pytorch-lightning/issues/20692))


- Fixed a potential deadlock by synchronizing SIGTERM handling across DDP ranks ([#20825](https://github.com/Lightning-AI/pytorch-lightning/issues/20825))


---

## [2.5.1] - 2025-03-18
25 changes: 25 additions & 0 deletions src/lightning/pytorch/loops/training_epoch_loop.py
@@ -11,11 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import math
from collections import OrderedDict
from dataclasses import dataclass
from typing import Any, Optional, Union

import torch
import torch.distributed as dist
from typing_extensions import override

import lightning.pytorch as pl
@@ -249,6 +252,21 @@ def _on_before_fetch(self) -> None:
    def _on_after_fetch(self) -> None:
        self.trainer.profiler.stop(f"[{self.__class__.__name__}].train_dataloader_next")

    def _broadcast_sigterm_tensor(self):
        try:
            sigterm_tensor = torch.tensor(
                [1 if getattr(self.trainer, "received_sigterm", False) else 0],
                device=self.trainer.strategy.root_device,
            )
            dist.broadcast(sigterm_tensor, src=0)
        except Exception:
            sigterm_tensor = torch.tensor([0], device=self.trainer.strategy.root_device)

        if sigterm_tensor.item() == 1:
            with contextlib.suppress(Exception):
                dist.barrier()  # prevent deadlocks by syncing all ranks before exit
            raise SIGTERMException()

    def advance(self, data_fetcher: _DataFetcher) -> None:
        """Runs a single training batch.

@@ -272,6 +290,13 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
        # we are going to train first so the val loop does not need to restart
        self.val_loop.restarting = False

        # =====================================================================

        if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
            self._broadcast_sigterm_tensor()

        # =====================================================================

        if using_dataloader_iter := isinstance(data_fetcher, _DataLoaderIterDataFetcher):
            dataloader_iter = next(data_fetcher)
            # hook's batch_idx and dataloader_idx arguments correctness cannot be guaranteed in this setting
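For readers less familiar with the collective-call pattern used above, the following is a minimal, self-contained sketch of the same idea outside of Lightning, assuming a CPU "gloo" process group; the function and variable names (_worker, sigterm_received) and the port number are illustrative only, not part of this PR.

# Sketch: every rank joins a broadcast of a one-element flag once per step,
# so a SIGTERM observed on rank 0 stops all ranks at the same point instead
# of leaving some of them blocked inside a later collective.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29501"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    sigterm_received = rank == 0  # pretend rank 0 has already caught SIGTERM

    for step in range(5):
        # rank 0 supplies the value, the other ranks receive it
        flag = torch.tensor([1 if (rank == 0 and sigterm_received) else 0])
        dist.broadcast(flag, src=0)
        if flag.item() == 1:
            dist.barrier()  # keep the ranks in lockstep before exiting
            print(f"rank {rank}: stopping at step {step}")
            break
        # ... the normal training step would run here ...

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2, join=True)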
@@ -106,8 +106,9 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
            model_checkpoint = LitModelCheckpoint(model_registry=self.trainer._model_registry)
        else:
            rank_zero_info(
-               "Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable"
-               " `LitModelCheckpoint` for automatic upload to the Lightning model registry."
+               "💡 Tip: For seamless cloud uploads and versioning,"
+               " try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint,"
+               " which syncs automatically with the Lightning model registry."
            )
            model_checkpoint = ModelCheckpoint()
        self.trainer.callbacks.append(model_checkpoint)
11 changes: 9 additions & 2 deletions src/lightning/pytorch/trainer/connectors/signal_connector.py
@@ -7,6 +7,9 @@
from types import FrameType
from typing import Any, Callable, Union

import torch
import torch.distributed as dist

import lightning.pytorch as pl
from lightning.fabric.plugins.environments import SLURMEnvironment
from lightning.fabric.utilities.imports import _IS_WINDOWS
@@ -104,12 +107,16 @@ def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:

    def _sigterm_notifier_fn(self, signum: _SIGNUM, _: FrameType) -> None:
        log.info(rank_prefixed_message(f"Received SIGTERM: {signum}", self.trainer.local_rank))
        # subprocesses killing the parent process is not supported, only the parent (rank 0) does it
        if not self.received_sigterm:
            # send the same signal to the subprocesses
            launcher = self.trainer.strategy.launcher
            if launcher is not None:
                launcher.kill(signum)

            # New broadcast logic
            if dist.is_available() and dist.is_initialized() and self.trainer.world_size > 1:
                sigterm_tensor = torch.tensor([1], device=self.trainer.strategy.root_device)
                dist.broadcast(sigterm_tensor, src=0)

        self.received_sigterm = True

    def _sigterm_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
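As a companion sketch for the handler side, the snippet below registers a SIGTERM handler that records the signal and, on rank 0 of an initialized process group, pushes a flag to the peers. The explicit rank check and the module-level received_sigterm variable are simplifications for illustration, not the connector's actual attributes.

import signal
from types import FrameType
from typing import Optional

import torch
import torch.distributed as dist

received_sigterm = False  # stand-in for the connector's received_sigterm attribute


def _on_sigterm(signum: int, frame: Optional[FrameType]) -> None:
    """Record the signal; in a distributed run, rank 0 pushes a flag to the peers."""
    global received_sigterm
    received_sigterm = True
    if dist.is_available() and dist.is_initialized() and dist.get_rank() == 0:
        flag = torch.tensor([1])
        dist.broadcast(flag, src=0)  # matched by the per-step broadcast on the other ranks


signal.signal(signal.SIGTERM, _on_sigterm)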
80 changes: 80 additions & 0 deletions tests/tests_pytorch/trainer/test_ddp_sigterm_handling.py
@@ -0,0 +1,80 @@
import os
import signal
import time

import pytest
import torch
import torch.multiprocessing as mp

from lightning.pytorch import LightningModule, Trainer, seed_everything
from lightning.pytorch.demos.boring_classes import BoringDataModule
from lightning.pytorch.strategies.ddp import DDPStrategy
from lightning.pytorch.utilities.exceptions import SIGTERMException

# Skip the test if DDP or multiple devices are not available

pytestmark = pytest.mark.skipif(
    not torch.distributed.is_available() or torch.cuda.device_count() < 2,
    reason="Requires torch.distributed and at least 2 CUDA devices",
)


class DummyModel(LightningModule):
    def training_step(self, batch, batch_idx):
        # Simulate SIGTERM in rank 0 at batch 2
        if self.trainer.global_rank == 0 and batch_idx == 2:
            time.sleep(3)  # Let other ranks proceed to the next batch
            os.kill(os.getpid(), signal.SIGTERM)
        return super().training_step(batch, batch_idx)


def run_ddp_sigterm(rank, world_size, tmpdir):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    seed_everything(42)

    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    model = DummyModel()
    datamodule = BoringDataModule()

    trainer = Trainer(
        accelerator="cuda" if torch.cuda.is_available() else "cpu",
        strategy=DDPStrategy(find_unused_parameters=False),
        devices=world_size,
        num_nodes=1,
        max_epochs=3,
        default_root_dir=tmpdir,
        enable_checkpointing=False,
        enable_progress_bar=False,
        enable_model_summary=False,
        logger=False,
    )

    try:
        trainer.fit(model, datamodule=datamodule)
    except SIGTERMException:
        # Test passed: SIGTERM was properly raised and caught
        print(f"[Rank {rank}] Caught SIGTERMException successfully.")
    except Exception as e:
        pytest.fail(f"[Rank {rank}] Unexpected exception: {e}")


def test_ddp_sigterm_handling(tmp_path):
    world_size = 2
    mp.spawn(run_ddp_sigterm, args=(world_size, tmp_path), nprocs=world_size, join=True)


@pytest.mark.skipif(
    not torch.distributed.is_available(),
    reason="Requires torch.distributed",
)
@pytest.mark.skipif(
    torch.cuda.is_available() and torch.cuda.device_count() < 2,
    reason="Requires >=2 CUDA devices or use CPU",
)
def test_sigterm_handling_ddp(tmp_path):
    test_ddp_sigterm_handling(tmp_path)
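From the user's side, the behavior the test checks for is that every rank raises SIGTERMException instead of hanging. A brief sketch of handling that around trainer.fit, using the BoringModel demo class as a stand-in for a real LightningModule:

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.exceptions import SIGTERMException

trainer = Trainer(max_epochs=1, logger=False, enable_checkpointing=False)
try:
    trainer.fit(BoringModel())
except SIGTERMException:
    # with the synchronized handling, all ranks reach this point together
    print("Training interrupted by SIGTERM; shutting down cleanly.")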