Skip to content

Commit f880c73

Browse files
committed
Fix restarting attribute for lr finder
1 parent 08d14ec commit f880c73

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

src/pytorch_lightning/tuner/lr_finder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ def lr_find(
270270
# Restore initial state of model
271271
trainer._checkpoint_connector.restore(ckpt_path)
272272
trainer.strategy.remove_checkpoint(ckpt_path)
273+
trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True
273274

274275
return lr_finder
275276

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,43 @@ def test_if_lr_finder_callback_already_configured():
441441
trainer.tune(model)
442442

443443

444+
def test_lr_finder_callback_restarting(tmpdir):
445+
"""Test that `LearningRateFinder` does not set restarting=True when loading checkpoint."""
446+
447+
class MyBoringModel(BoringModel):
448+
def __init__(self):
449+
super().__init__()
450+
self.learning_rate = 0.123
451+
452+
def configure_optimizers(self):
453+
return torch.optim.SGD(self.parameters(), lr=self.learning_rate)
454+
455+
class CustomLearningRateFinder(LearningRateFinder):
456+
milestones = (1,)
457+
458+
def lr_find(self, trainer, pl_module) -> None:
459+
super().lr_find(trainer, pl_module)
460+
assert not trainer.fit_loop.restarting
461+
462+
def on_train_epoch_start(self, trainer, pl_module):
463+
if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
464+
self.lr_find(trainer, pl_module)
465+
466+
model = MyBoringModel()
467+
trainer = Trainer(
468+
default_root_dir=tmpdir,
469+
max_epochs=3,
470+
callbacks=[CustomLearningRateFinder(early_stop_threshold=None, update_attr=True)],
471+
limit_train_batches=10,
472+
limit_val_batches=0,
473+
limit_test_batches=00,
474+
num_sanity_val_steps=0,
475+
enable_model_summary=False,
476+
)
477+
478+
trainer.fit(model)
479+
480+
444481
@mock.patch.dict(os.environ, os.environ.copy(), clear=True)
445482
@RunIf(standalone=True)
446483
def test_lr_finder_with_ddp(tmpdir):

0 commit comments

Comments
 (0)