Skip to content

Commit 7170a91

Browse files
Queuecumber authored and awaelchli committed
Support Slurm Autorequeue for Array Jobs (#15040)
Signed-off-by: Max Ehrlich <[email protected]>
Co-authored-by: awaelchli <[email protected]>
1 parent c887552 commit 7170a91

File tree

3 files changed

+36
-1
lines changed

3 files changed

+36
-1
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Added
1111

12+
- Added support for requeueing slurm array jobs ([#15022](https://github.com/Lightning-AI/lightning/issues/15022))
13+
14+
1215
- Added native AMP support for `ddp_fork` (and associated alias strategies) with CUDA GPUs ([#14983](https://github.com/Lightning-AI/lightning/pull/14983))
1316

1417

src/pytorch_lightning/trainer/connectors/signal_connector.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,13 @@ def slurm_sigusr_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
7878

7979
if self.trainer.is_global_zero:
8080
# find job id
81-
job_id = os.environ["SLURM_JOB_ID"]
81+
array_job_id = os.getenv("SLURM_ARRAY_JOB_ID")
82+
if array_job_id is not None:
83+
array_task_id = os.environ["SLURM_ARRAY_TASK_ID"]
84+
job_id = f"{array_job_id}_{array_task_id}"
85+
else:
86+
job_id = os.environ["SLURM_JOB_ID"]
87+
8288
cmd = ["scontrol", "requeue", job_id]
8389

8490
# requeue job

tests/tests_pytorch/trainer/connectors/test_signal_connector.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,32 @@ def test_auto_requeue_custom_signal_flag(auto_requeue, requeue_signal):
100100
connector.teardown()
101101

102102

103+
@RunIf(skip_windows=True)
@mock.patch("pytorch_lightning.trainer.connectors.signal_connector.call")
@mock.patch("pytorch_lightning.trainer.Trainer.save_checkpoint", mock.MagicMock())
@mock.patch.dict(os.environ, {"SLURM_JOB_ID": "12345"})
def test_auto_requeue_job(call_mock):
    """The SIGUSR handler requeues a plain (non-array) SLURM job by its ``SLURM_JOB_ID``."""
    # Pretend `scontrol requeue` succeeded (exit code 0).
    call_mock.return_value = 0
    signal_connector = SignalConnector(Trainer(plugins=[SLURMEnvironment()]))
    # Invoke the handler directly; signum/frame are unused by the requeue path.
    signal_connector.slurm_sigusr_handler_fn(None, None)
    expected_cmd = ["scontrol", "requeue", "12345"]
    call_mock.assert_called_once_with(expected_cmd)
    signal_connector.teardown()
114+
115+
116+
@RunIf(skip_windows=True)
@mock.patch("pytorch_lightning.trainer.connectors.signal_connector.call")
@mock.patch("pytorch_lightning.trainer.Trainer.save_checkpoint", mock.MagicMock())
@mock.patch.dict(os.environ, {"SLURM_JOB_ID": "12346", "SLURM_ARRAY_JOB_ID": "12345", "SLURM_ARRAY_TASK_ID": "2"})
def test_auto_requeue_array_job(call_mock):
    """For an array job, the handler requeues ``<array_job_id>_<array_task_id>``, not ``SLURM_JOB_ID``."""
    # Pretend `scontrol requeue` succeeded (exit code 0).
    call_mock.return_value = 0
    signal_connector = SignalConnector(Trainer(plugins=[SLURMEnvironment()]))
    # Invoke the handler directly; signum/frame are unused by the requeue path.
    signal_connector.slurm_sigusr_handler_fn(None, None)
    expected_cmd = ["scontrol", "requeue", "12345_2"]
    call_mock.assert_called_once_with(expected_cmd)
    signal_connector.teardown()
127+
128+
103129
def _registering_signals():
104130
trainer = Trainer()
105131
trainer._signal_connector.register_signal_handlers()

0 commit comments

Comments
 (0)