Skip to content

Commit 6db6431

Browse files
awaelchli and Borda committed
Fix min-epochs and early-stopping triggering too many validation runs (#16719)
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 73cd956 commit 6db6431

File tree

5 files changed

+35
-7
lines changed

5 files changed

+35
-7
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2020
### Fixed
2121

2222
- Fixed an attribute error and improved input validation for invalid strategy types being passed to Trainer ([#16693](https://github.com/Lightning-AI/lightning/pull/16693))
23+
- Fixed early stopping triggering extra validation runs after reaching `min_epochs` or `min_steps` ([#16719](https://github.com/Lightning-AI/lightning/pull/16719))
2324

2425

2526
## [1.9.1] - 2023-02-10

src/pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def done(self) -> bool:
107107
if self.trainer.should_stop:
108108
# early stopping
109109
min_epochs = self.trainer.fit_loop.min_epochs
110-
should_stop_early = self.trainer.fit_loop._should_stop_early
110+
should_stop_early = self.trainer.fit_loop._can_stop_early
111111
if not should_stop_early:
112112
self._warning_cache.info(
113113
f"Trainer was signaled to stop but the required `min_epochs={min_epochs!r}` or"
@@ -482,7 +482,9 @@ def _should_check_val_fx(self) -> bool:
482482
if is_last_batch and is_infinite_dataset:
483483
return True
484484

485-
if self.trainer.should_stop:
485+
if self.trainer.should_stop and self.trainer.fit_loop._can_stop_early:
486+
# allow validation if requesting to stop early through `Trainer.should_stop` (e.g. by early stopping)
487+
# and when the loop allows to stop (min_epochs/steps met)
486488
return True
487489

488490
# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch

src/pytorch_lightning/loops/fit_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def _results(self) -> _ResultCollection:
147147
raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")
148148

149149
@property
150-
def _should_stop_early(self) -> bool:
150+
def _can_stop_early(self) -> bool:
151151
met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
152152
met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
153153
return met_min_epochs and met_min_steps
@@ -175,7 +175,7 @@ def done(self) -> bool:
175175
rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
176176
return True
177177

178-
if self.trainer.should_stop and self._should_stop_early:
178+
if self.trainer.should_stop and self._can_stop_early:
179179
rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
180180
return True
181181

tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import logging
15-
from unittest.mock import patch
15+
from unittest.mock import Mock, patch
1616

1717
import pytest
1818

@@ -217,4 +217,29 @@ def test_should_stop_early_stopping_conditions_not_met(
217217
assert trainer.fit_loop.epoch_loop.done is epoch_loop_done
218218

219219
assert (message in caplog.text) is raise_info_msg
220-
assert trainer.fit_loop._should_stop_early is early_stop
220+
assert trainer.fit_loop._can_stop_early is early_stop
221+
222+
223+
@pytest.mark.parametrize("min_epochs,min_steps,val_count", [(3, None, 3), (None, 3, 2)])
224+
def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, tmp_path):
225+
"""Regression test for issue #15708.
226+
227+
Test that the request for `should_stop=True` only triggers validation when Trainer is allowed to stop
228+
(min_epochs/steps is satisfied).
229+
"""
230+
model = BoringModel()
231+
trainer = Trainer(
232+
default_root_dir=tmp_path,
233+
num_sanity_val_steps=0,
234+
limit_val_batches=2,
235+
limit_train_batches=2,
236+
max_epochs=3,
237+
min_epochs=min_epochs,
238+
min_steps=min_steps,
239+
enable_model_summary=False,
240+
enable_checkpointing=False,
241+
)
242+
trainer.should_stop = True # Request to stop before min_epochs/min_steps are reached
243+
trainer.fit_loop.epoch_loop.val_loop.run = Mock()
244+
trainer.fit(model)
245+
assert trainer.fit_loop.epoch_loop.val_loop.run.call_count == val_count

tests/tests_pytorch/loops/test_training_loop.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,4 +228,4 @@ def test_should_stop_early_stopping_conditions_met(
228228
assert trainer.fit_loop.done is fit_loop_done
229229

230230
assert (message in caplog.text) is raise_debug_msg
231-
assert trainer.fit_loop._should_stop_early is early_stop
231+
assert trainer.fit_loop._can_stop_early is early_stop

0 commit comments

Comments
 (0)