@@ -441,6 +441,43 @@ def test_if_lr_finder_callback_already_configured():
441
441
trainer .tune (model )
442
442
443
443
444
def test_lr_finder_callback_restarting(tmpdir):
    """Test that `LearningRateFinder` does not set restarting=True when loading checkpoint.

    The LR finder saves/restores trainer state around its search; this test asserts
    that after `lr_find` runs (including a mid-training re-run at a milestone epoch),
    the fit loop is not left in the `restarting` state.
    """

    class MyBoringModel(BoringModel):
        def __init__(self):
            super().__init__()
            # Starting LR; `update_attr=True` below lets the finder overwrite it.
            self.learning_rate = 0.123

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

    class CustomLearningRateFinder(LearningRateFinder):
        # Epochs (besides epoch 0) at which the LR search is re-run mid-training.
        milestones = (1,)

        def lr_find(self, trainer, pl_module) -> None:
            super().lr_find(trainer, pl_module)
            # Core assertion of this test: restoring the pre-search checkpoint
            # must not flag the fit loop as restarting.
            assert not trainer.fit_loop.restarting

        def on_train_epoch_start(self, trainer, pl_module):
            if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
                self.lr_find(trainer, pl_module)

    model = MyBoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=3,
        callbacks=[CustomLearningRateFinder(early_stop_threshold=None, update_attr=True)],
        limit_train_batches=10,
        limit_val_batches=0,
        limit_test_batches=0,  # was the confusing literal `00`; same value, clearer
        num_sanity_val_steps=0,
        enable_model_summary=False,
    )

    trainer.fit(model)
444
481
@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
445
482
@RunIf (standalone = True )
446
483
def test_lr_finder_with_ddp (tmpdir ):
0 commit comments