@@ -100,6 +100,36 @@ def test_auto_requeue_custom_signal_flag(auto_requeue, requeue_signal):
100
100
connector .teardown ()
101
101
102
102
103
+ @mock .patch ("pytorch_lightning.trainer.connectors.signal_connector.call" , mock .MagicMock (return_value = 0 ))
104
+ @mock .patch ("pytorch_lightning.trainer.Trainer.save_checkpoint" , mock .MagicMock ())
105
+ @mock .patch .dict (os .environ , {"SLURM_JOB_ID" : "12345" })
106
+ def test_auto_requeue_job ():
107
+ from pytorch_lightning .trainer .connectors .signal_connector import call
108
+
109
+ trainer = Trainer (plugins = [SLURMEnvironment ()])
110
+ connector = SignalConnector (trainer )
111
+ connector .slurm_sigusr_handler_fn (signal .SIGUSR1 , None )
112
+
113
+ assert call .call_args_list [0 ].args [0 ] == ["scontrol" , "requeue" , "12345" ]
114
+
115
+ connector .teardown ()
116
+
117
+
118
+ @mock .patch ("pytorch_lightning.trainer.connectors.signal_connector.call" , mock .MagicMock (return_value = 0 ))
119
+ @mock .patch ("pytorch_lightning.trainer.Trainer.save_checkpoint" , mock .MagicMock ())
120
+ @mock .patch .dict (os .environ , {"SLURM_JOB_ID" : "12345" , "SLURM_ARRAY_JOB_ID" : "12345" , "SLURM_ARRAY_TASK_ID" : "1" })
121
+ def test_auto_requeue_array_job ():
122
+ from pytorch_lightning .trainer .connectors .signal_connector import call
123
+
124
+ trainer = Trainer (plugins = [SLURMEnvironment ()])
125
+ connector = SignalConnector (trainer )
126
+ connector .slurm_sigusr_handler_fn (signal .SIGUSR1 , None )
127
+
128
+ assert call .call_args_list [0 ].args [0 ] == ["scontrol" , "requeue" , "12345_1" ]
129
+
130
+ connector .teardown ()
131
+
132
+
103
133
def _registering_signals ():
104
134
trainer = Trainer ()
105
135
trainer ._signal_connector .register_signal_handlers ()
0 commit comments