Commit f142ce0

justusschock authored and Borda committed
Fix restarting attribute for lr finder (#15620)
(cherry picked from commit 15184c6)
1 parent 3ccfb80 commit f142ce0

4 files changed, 64 insertions(+), 8 deletions(-)

src/pytorch_lightning/CHANGELOG.md

Lines changed: 7 additions & 1 deletion
@@ -6,9 +6,15 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [1.8.4] - 2022-12-08
 
+### Changed
+
 - Direct support for compiled models ([#15922](https://github.com/Lightning-AI/lightning/pull/15922))
+
+### Fixed
+
 - Fixed issue with unsupported torch.inference_mode() on hpu backends ([#15918](https://github.com/Lightning-AI/lightning/pull/15918))
-- Fix LRScheduler import for PyTorch 2.0 ([#15940](https://github.com/Lightning-AI/lightning/pull/15940))
+- Fixed LRScheduler import for PyTorch 2.0 ([#15940](https://github.com/Lightning-AI/lightning/pull/15940))
+- Fixed `fit_loop.restarting` to be `False` for lr finder ([#15620](https://github.com/Lightning-AI/lightning/pull/15620))
 
 
 ## [1.8.3] - 2022-11-22

src/pytorch_lightning/callbacks/lr_finder.py

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ def __init__(
         max_lr: float = 1,
         num_training_steps: int = 100,
         mode: str = "exponential",
-        early_stop_threshold: float = 4.0,
+        early_stop_threshold: Optional[float] = 4.0,
         update_attr: bool = False,
     ) -> None:
         mode = mode.lower()
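The widened `Optional[float]` annotation reflects behaviour the new test relies on: passing `early_stop_threshold=None` disables the diverging-loss early stop during the sweep. A minimal usage sketch, assuming a `LightningModule` named `model` that exposes a `self.learning_rate` attribute (only the callback arguments come from this diff, the rest is illustrative setup):

# Minimal sketch (assumed setup, not part of the diff): run the LR finder callback
# with early stopping disabled, as the new test below does.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateFinder

lr_finder_cb = LearningRateFinder(
    num_training_steps=100,     # length of the sweep, as in the signature above
    early_stop_threshold=None,  # None disables the diverging-loss early stop
    update_attr=True,           # write the suggested lr back to the model attribute
)
trainer = Trainer(max_epochs=3, callbacks=[lr_finder_cb])
# trainer.fit(model)  # `model` is an assumed LightningModule with `self.learning_rate`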

src/pytorch_lightning/tuner/lr_finder.py

Lines changed: 9 additions & 6 deletions
@@ -203,7 +203,7 @@ def lr_find(
     max_lr: float = 1,
     num_training: int = 100,
     mode: str = "exponential",
-    early_stop_threshold: float = 4.0,
+    early_stop_threshold: Optional[float] = 4.0,
     update_attr: bool = False,
 ) -> Optional[_LRFinder]:
     """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
@@ -219,6 +219,8 @@ def lr_find(
     ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
     trainer.save_checkpoint(ckpt_path)
 
+    start_steps = trainer.global_step
+
     # Arguments we adjust during the lr finder, save for restoring
     params = __lr_finder_dump_params(trainer)
 
@@ -239,7 +241,7 @@ def lr_find(
     _try_loop_run(trainer, params)
 
     # Prompt if we stopped early
-    if trainer.global_step != num_training:
+    if trainer.global_step != num_training + start_steps:
         log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
 
     # Transfer results from callback to lr finder object
@@ -263,6 +265,7 @@ def lr_find(
     # Restore initial state of model
     trainer._checkpoint_connector.restore(ckpt_path)
     trainer.strategy.remove_checkpoint(ckpt_path)
+    trainer.fit_loop.restarting = False  # reset restarting flag as checkpoint restoring sets it to True
 
     return lr_finder
 
@@ -282,7 +285,7 @@ def __lr_finder_dump_params(trainer: "pl.Trainer") -> Dict[str, Any]:
     }
 
 
-def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: float) -> None:
+def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_stop_threshold: Optional[float]) -> None:
     from pytorch_lightning.loggers.logger import DummyLogger
 
     trainer.strategy.lr_scheduler_configs = []
@@ -293,8 +296,8 @@ def __lr_finder_reset_params(trainer: "pl.Trainer", num_training: int, early_sto
     trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]
     # No logging
     trainer.logger = DummyLogger() if trainer.logger is not None else None
-    # Max step set to number of iterations
-    trainer.fit_loop.max_steps = num_training
+    # Max step set to number of iterations starting at current number of iterations
+    trainer.fit_loop.max_steps = num_training + trainer.global_step
     trainer.limit_val_batches = num_training
 
 
@@ -332,7 +335,7 @@ class _LRCallback(Callback):
     def __init__(
         self,
         num_training: int,
-        early_stop_threshold: float = 4.0,
+        early_stop_threshold: Optional[float] = 4.0,
         progress_bar_refresh_rate: int = 0,
         beta: float = 0.98,
     ):
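The substance of the fix lives in this file: `lr_find` saves a temporary checkpoint, runs a shortened fit loop, then restores that checkpoint, and restoring flips `fit_loop.restarting` to `True`, so without the new reset the subsequent real training would behave as if it were resuming from a checkpoint. The sweep budget is also made relative to the step counter so the finder works mid-training. Below is a condensed sketch of the patched flow, not the verbatim source; `_lr_sweep_sketch` is a hypothetical name, every attribute and call it uses appears in the diff above:

# Condensed sketch of the patched lr_find flow (hypothetical wrapper, not the real function).
import os
import uuid

import pytorch_lightning as pl


def _lr_sweep_sketch(trainer: "pl.Trainer", num_training: int) -> None:
    ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
    trainer.save_checkpoint(ckpt_path)
    start_steps = trainer.global_step                        # new: remember the current step count
    trainer.fit_loop.max_steps = num_training + start_steps  # new: budget is relative to it
    # ... _try_loop_run(trainer, params) drives the shortened fit loop here ...
    trainer._checkpoint_connector.restore(ckpt_path)         # put model and loop state back
    trainer.strategy.remove_checkpoint(ckpt_path)
    trainer.fit_loop.restarting = False                      # new: restoring flips this to True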

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 47 additions & 0 deletions
@@ -438,3 +438,50 @@ def test_if_lr_finder_callback_already_configured():
 
     with pytest.raises(MisconfigurationException, match="Trainer is already configured with a .* callback"):
         trainer.tune(model)
+
+
+def test_lr_finder_callback_restarting(tmpdir):
+    """Test that `LearningRateFinder` does not set restarting=True when loading checkpoint."""
+
+    num_lr_steps = 100
+
+    class MyBoringModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.learning_rate = 0.123
+
+        def on_train_batch_start(self, batch, batch_idx):
+            if getattr(self, "_expected_max_steps", None) is not None:
+                assert self.trainer.fit_loop.max_steps == self._expected_max_steps
+
+        def configure_optimizers(self):
+            return torch.optim.SGD(self.parameters(), lr=self.learning_rate)
+
+    class CustomLearningRateFinder(LearningRateFinder):
+        milestones = (1,)
+
+        def lr_find(self, trainer, pl_module) -> None:
+            pl_module._expected_max_steps = trainer.global_step + self._num_training_steps
+            super().lr_find(trainer, pl_module)
+            pl_module._expected_max_steps = None
+            assert not trainer.fit_loop.restarting
+
+        def on_train_epoch_start(self, trainer, pl_module):
+            if trainer.current_epoch in self.milestones or trainer.current_epoch == 0:
+                self.lr_find(trainer, pl_module)
+
+    model = MyBoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=3,
+        callbacks=[
+            CustomLearningRateFinder(early_stop_threshold=None, update_attr=True, num_training_steps=num_lr_steps)
+        ],
+        limit_train_batches=10,
+        limit_val_batches=0,
+        limit_test_batches=0,
+        num_sanity_val_steps=0,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(model)
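The test drives the scenario that exposed the bug: a `LearningRateFinder` subclass with `milestones = (1,)` re-runs the sweep at epoch 1, when `trainer.global_step` is already above zero, then asserts that `fit_loop.max_steps` was offset by the current step count and that `fit_loop.restarting` is back to `False` after the checkpoint restore. It can be run on its own with `pytest tests/tests_pytorch/tuner/test_lr_finder.py::test_lr_finder_callback_restarting` (standard pytest node-id syntax).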
