@@ -203,7 +203,7 @@ def lr_find(
     max_lr: float = 1,
     num_training: int = 100,
     mode: str = "exponential",
-    early_stop_threshold: float = 4.0,
+    early_stop_threshold: Optional[float] = 4.0,
     update_attr: bool = False,
 ) -> Optional[_LRFinder]:
     """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
@@ -219,6 +219,8 @@ def lr_find(
     ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
     trainer.save_checkpoint(ckpt_path)

+    start_steps = trainer.global_step
+
     # Arguments we adjust during the lr finder, save for restoring
     params = __lr_finder_dump_params(trainer)

@@ -239,7 +241,7 @@ def lr_find(
     _try_loop_run(trainer, params)

     # Prompt if we stopped early
-    if trainer.global_step != num_training:
+    if trainer.global_step != num_training + start_steps:
         log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")

     # Transfer results from callback to lr finder object
@@ -263,6 +265,7 @@ def lr_find(
     # Restore initial state of model
     trainer._checkpoint_connector.restore(ckpt_path)
     trainer.strategy.remove_checkpoint(ckpt_path)
+    trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True

     return lr_finder

@@ -282,7 +285,7 @@ def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
     }


-def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None:
+def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: Optional[float]) -> None:
     from pytorch_lightning.loggers.logger import DummyLogger

     trainer.strategy.lr_scheduler_configs = []
@@ -293,8 +296,8 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto
     trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
     # No logging
     trainer.logger = DummyLogger() if trainer.logger is not None else None
-    # Max step set to number of iterations
-    trainer.fit_loop.max_steps = num_training
+    # Max step set to number of iterations starting at current number of iterations
+    trainer.fit_loop.max_steps = num_training + trainer.global_step
     trainer.limit_val_batches = num_training

@@ -332,7 +335,7 @@ class _LRCallback(Callback):
     def __init__(
         self,
         num_training: int,
-        early_stop_threshold: float = 4.0,
+        early_stop_threshold: Optional[float] = 4.0,
         progress_bar_refresh_rate: int = 0,
         beta: float = 0.98,
     ):
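Context for the change (not part of the diff): the patch lets the learning-rate finder run on a trainer whose global_step is already non-zero, by offsetting fit_loop.max_steps with the current step count, checking divergence against num_training + start_steps, and clearing the restarting flag that the checkpoint restore sets. A minimal usage sketch of the scenario, assuming the trainer.tuner.lr_find entry point of this Lightning version; model and datamodule are placeholder objects:

# Hypothetical sketch, not part of the patch; model and datamodule are placeholders.
import pytorch_lightning as pl

trainer = pl.Trainer(max_epochs=2)
trainer.fit(model, datamodule=datamodule)  # after fitting, trainer.global_step > 0

# Re-running the finder on the same trainer: before this patch, fit_loop.max_steps was set
# to num_training alone, so the search loop could end immediately (and the "stopped early"
# message fire spuriously) because global_step already exceeded it.
lr_finder = trainer.tuner.lr_find(model, datamodule=datamodule, num_training=100)
print(lr_finder.suggestion())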