-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Fix: Synchronize SIGTERM Handling in DDP to Prevent Deadlocks #20825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+119
−4
Merged
Changes from 23 commits
Commits
Show all changes
26 commits
Select commit
Hold shift + click to select a range
6a1bbf1
Update signal_connector.py
KAVYANSHTYAGI 2761ad8
Update training_epoch_loop.py
KAVYANSHTYAGI 93c3e69
Create test_ddp_sigterm_handling.py
KAVYANSHTYAGI b7cef51
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 2fc5178
Update training_epoch_loop.py
KAVYANSHTYAGI c8c9523
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] ebbe682
Update test_ddp_sigterm_handling.py
KAVYANSHTYAGI f327aa7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f50b3a9
Update training_epoch_loop.py
KAVYANSHTYAGI 49e2fab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 5600bff
Update training_epoch_loop.py
KAVYANSHTYAGI 873a792
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 4a275da
Update test_ddp_sigterm_handling.py
KAVYANSHTYAGI b792073
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 63c922d
linter
Borda 22dc0ab
Merge branch 'Lightning-AI:master' into sigterm-deadlock
KAVYANSHTYAGI 71189de
Update training_epoch_loop.py
KAVYANSHTYAGI ec210cb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] cb25184
Merge branch 'Lightning-AI:master' into sigterm-deadlock
KAVYANSHTYAGI 81b3d24
Update training_epoch_loop.py
KAVYANSHTYAGI 67a3b57
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 79b39db
Merge branch 'master' into sigterm-deadlock
KAVYANSHTYAGI d1ab68f
update + chlog
Borda 7293e6e
Apply suggestions from code review
Borda 857637a
linting
Borda 36e9ecf
type
Borda File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import os | ||
import signal | ||
import time | ||
|
||
import pytest | ||
import torch | ||
import torch.multiprocessing as mp | ||
|
||
from lightning.pytorch import LightningModule, Trainer, seed_everything | ||
from lightning.pytorch.demos.boring_classes import BoringDataModule | ||
from lightning.pytorch.strategies.ddp import DDPStrategy | ||
from lightning.pytorch.utilities.exceptions import SIGTERMException | ||
|
||
# Skip the test if DDP or multiple devices are not available | ||
|
||
pytestmark = pytest.mark.skipif( | ||
not torch.distributed.is_available() or torch.cuda.device_count() < 2, | ||
reason="Requires torch.distributed and at least 2 CUDA devices", | ||
) | ||
|
||
|
||
class DummyModel(LightningModule): | ||
def training_step(self, batch, batch_idx): | ||
# Simulate SIGTERM in rank 0 at batch 2 | ||
if self.trainer.global_rank == 0 and batch_idx == 2: | ||
time.sleep(3) # Let other ranks proceed to the next batch | ||
os.kill(os.getpid(), signal.SIGTERM) | ||
return super().training_step(batch, batch_idx) | ||
|
||
|
||
def run_ddp_sigterm(rank, world_size, tmpdir): | ||
os.environ["MASTER_ADDR"] = "localhost" | ||
os.environ["MASTER_PORT"] = "12355" | ||
os.environ["RANK"] = str(rank) | ||
os.environ["WORLD_SIZE"] = str(world_size) | ||
|
||
seed_everything(42) | ||
|
||
torch.cuda.set_device(rank) if torch.cuda.is_available() else None | ||
|
||
model = DummyModel() | ||
datamodule = BoringDataModule() | ||
|
||
trainer = Trainer( | ||
accelerator="cuda" if torch.cuda.is_available() else "cpu", | ||
strategy=DDPStrategy(find_unused_parameters=False), | ||
devices=world_size, | ||
num_nodes=1, | ||
max_epochs=3, | ||
default_root_dir=tmpdir, | ||
enable_checkpointing=False, | ||
enable_progress_bar=False, | ||
enable_model_summary=False, | ||
logger=False, | ||
) | ||
|
||
try: | ||
trainer.fit(model, datamodule=datamodule) | ||
except SIGTERMException: | ||
# Test passed: SIGTERM was properly raised and caught | ||
print(f"[Rank {rank}] Caught SIGTERMException successfully.") | ||
except Exception as e: | ||
pytest.fail(f"[Rank {rank}] Unexpected exception: {e}") | ||
|
||
|
||
def test_ddp_sigterm_handling(tmp_path): | ||
world_size = 2 | ||
mp.spawn(run_ddp_sigterm, args=(world_size, tmp_path), nprocs=world_size, join=True) | ||
|
||
|
||
@pytest.mark.skipif( | ||
not torch.distributed.is_available(), | ||
reason="Requires torch.distributed", | ||
) | ||
@pytest.mark.skipif( | ||
torch.cuda.is_available() and torch.cuda.device_count() < 2, | ||
reason="Requires >=2 CUDA devices or use CPU", | ||
) | ||
def test_sigterm_handling_ddp(tmp_path): | ||
test_ddp_sigterm_handling(tmp_path) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.