Skip to content

Commit 3fa4308

Browse files
committed
Add test cases for the requeue scontrol calls
Signed-off-by: Max Ehrlich <[email protected]>
1 parent 571059c commit 3fa4308

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

tests/tests_pytorch/trainer/connectors/test_signal_connector.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,36 @@ def test_auto_requeue_custom_signal_flag(auto_requeue, requeue_signal):
100100
connector.teardown()
101101

102102

103+
@mock.patch("pytorch_lightning.trainer.connectors.signal_connector.call", mock.MagicMock(return_value=0))
104+
@mock.patch("pytorch_lightning.trainer.Trainer.save_checkpoint", mock.MagicMock())
105+
@mock.patch.dict(os.environ, {"SLURM_JOB_ID": "12345"})
106+
def test_auto_requeue_job():
107+
from pytorch_lightning.trainer.connectors.signal_connector import call
108+
109+
trainer = Trainer(plugins=[SLURMEnvironment()])
110+
connector = SignalConnector(trainer)
111+
connector.slurm_sigusr_handler_fn(signal.SIGUSR1, None)
112+
113+
assert call.call_args_list[0].args[0] == ["scontrol", "requeue", "12345"]
114+
115+
connector.teardown()
116+
117+
118+
@mock.patch("pytorch_lightning.trainer.connectors.signal_connector.call", mock.MagicMock(return_value=0))
119+
@mock.patch("pytorch_lightning.trainer.Trainer.save_checkpoint", mock.MagicMock())
120+
@mock.patch.dict(os.environ, {"SLURM_JOB_ID": "12345", "SLURM_ARRAY_JOB_ID": "12345", "SLURM_ARRAY_TASK_ID": "1"})
121+
def test_auto_requeue_array_job():
122+
from pytorch_lightning.trainer.connectors.signal_connector import call
123+
124+
trainer = Trainer(plugins=[SLURMEnvironment()])
125+
connector = SignalConnector(trainer)
126+
connector.slurm_sigusr_handler_fn(signal.SIGUSR1, None)
127+
128+
assert call.call_args_list[0].args[0] == ["scontrol", "requeue", "12345_1"]
129+
130+
connector.teardown()
131+
132+
103133
def _registering_signals():
104134
trainer = Trainer()
105135
trainer._signal_connector.register_signal_handlers()

0 commit comments

Comments
 (0)