Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
-


- Fixed early stopping triggering extra validation runs before `min_epochs` or `min_steps` has been reached ([#16719](https://github.com/Lightning-AI/lightning/pull/16719))


## [1.9.1] - 2023-02-10

### Fixed
Expand Down
10 changes: 6 additions & 4 deletions src/lightning/pytorch/loops/epoch/training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,13 @@ def done(self) -> bool:
if self.trainer.should_stop:
# early stopping
min_epochs = self.trainer.fit_loop.min_epochs
should_stop_early = self.trainer.fit_loop._should_stop_early
if not should_stop_early:
can_stop_early = self.trainer.fit_loop._can_stop_early
if not can_stop_early:
self._warning_cache.info(
f"Trainer was signaled to stop but the required `min_epochs={min_epochs!r}` or"
f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
)
return should_stop_early
return can_stop_early

return False

Expand Down Expand Up @@ -389,7 +389,9 @@ def _should_check_val_fx(self) -> bool:
if is_last_batch and is_infinite_dataset:
return True

if self.trainer.should_stop:
if self.trainer.should_stop and self.trainer.fit_loop._can_stop_early:
# allow validation when a stop was requested through `Trainer.should_stop` (e.g. by early stopping)
# and the loop is allowed to stop (min_epochs/steps met)
return True

# TODO: let training/eval loop handle logic around limit_*_batches and val_check_batch
Expand Down
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/fit_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def _results(self) -> _ResultCollection:
raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")

@property
def _should_stop_early(self) -> bool:
def _can_stop_early(self) -> bool:
met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
return met_min_epochs and met_min_steps
Expand Down Expand Up @@ -170,7 +170,7 @@ def done(self) -> bool:
rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
return True

if self.trainer.should_stop and self._should_stop_early:
if self.trainer.should_stop and self._can_stop_early:
rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
return True

Expand Down
29 changes: 27 additions & 2 deletions tests/tests_pytorch/loops/epoch/test_training_epoch_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from unittest.mock import patch
from unittest.mock import Mock, patch

import pytest

Expand Down Expand Up @@ -75,4 +75,29 @@ def test_should_stop_early_stopping_conditions_not_met(
assert trainer.fit_loop.epoch_loop.done is epoch_loop_done

assert (message in caplog.text) is raise_info_msg
assert trainer.fit_loop._should_stop_early is early_stop
assert trainer.fit_loop._can_stop_early is early_stop


@pytest.mark.parametrize(
    "min_epochs,min_steps,val_count",
    [
        (3, None, 3),
        (None, 3, 2),
    ],
)
def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, tmp_path):
    """Regression test for issue #15708.

    ``Trainer.should_stop`` is set to ``True`` before fitting starts, and the validation loop's ``run`` is
    replaced with a mock. After ``fit`` returns, the mock's call count must equal ``val_count`` for the given
    ``min_epochs``/``min_steps`` setting, i.e. the stop request alone must not cause additional validation
    runs while the minimum is not yet satisfied.
    """
    model = BoringModel()
    trainer_kwargs = dict(
        default_root_dir=tmp_path,
        num_sanity_val_steps=0,
        limit_val_batches=2,
        limit_train_batches=2,
        max_epochs=3,
        min_epochs=min_epochs,
        min_steps=min_steps,
        enable_model_summary=False,
        enable_checkpointing=False,
    )
    trainer = Trainer(**trainer_kwargs)

    # Request to stop before min_epochs/min_steps are reached
    trainer.should_stop = True
    # Swap the validation run for a mock so invocations can be counted without running validation
    trainer.fit_loop.epoch_loop.val_loop.run = Mock()

    trainer.fit(model)

    assert trainer.fit_loop.epoch_loop.val_loop.run.call_count == val_count
2 changes: 1 addition & 1 deletion tests/tests_pytorch/loops/test_training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,4 +223,4 @@ def test_should_stop_early_stopping_conditions_met(
assert trainer.fit_loop.done is fit_loop_done

assert (message in caplog.text) is raise_debug_msg
assert trainer.fit_loop._should_stop_early is early_stop
assert trainer.fit_loop._can_stop_early is early_stop