Skip to content

Commit 617d03d

Browse files
SkafteNickiBordadeependujha
authored andcommitted
Warning on eval module (#21146)
* add warning * tests * changelog * Apply suggestions from code review --------- Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Deependu <[email protected]> (cherry picked from commit 90eba3f)
1 parent bcb5065 commit 617d03d

File tree

3 files changed

+46
-0
lines changed

3 files changed

+46
-0
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
88

99
## [unReleased] - 2025-09-DD
1010

11+
- Added `PossibleUserWarning` that is raised if modules are in eval mode when training starts ([#21146](https://github.com/Lightning-AI/pytorch-lightning/pull/21146))
12+
1113
### Changed
1214

1315
-

src/lightning/pytorch/loops/fit_loop.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,9 @@ def on_run_start(self) -> None:
414414
self.epoch_loop.val_loop.setup_data()
415415
trainer.training = True
416416

417+
# Check for modules in eval mode at training start
418+
self._warn_if_modules_in_eval_mode()
419+
417420
call._call_callback_hooks(trainer, "on_train_start")
418421
call._call_lightning_module_hook(trainer, "on_train_start")
419422
call._call_strategy_hook(trainer, "on_train_start")
@@ -515,6 +518,19 @@ def on_load_checkpoint(self, state_dict: dict) -> None:
515518
self._combined_loader_states_to_load = state_dict.get("combined_loader", [])
516519
super().on_load_checkpoint(state_dict)
517520

521+
def _warn_if_modules_in_eval_mode(self) -> None:
522+
"""Warn if any modules are in eval mode at the start of training."""
523+
model = self.trainer.lightning_module
524+
eval_modules = [name for name, module in model.named_modules() if not module.training]
525+
526+
if eval_modules:
527+
rank_zero_warn(
528+
f"Found {len(eval_modules)} module(s) in eval mode at the start of training."
529+
" This may lead to unexpected behavior during training. If this is intentional,"
530+
" you can ignore this warning.",
531+
category=PossibleUserWarning,
532+
)
533+
518534
def _should_accumulate(self) -> bool:
519535
"""Whether the gradients should be accumulated."""
520536
return self.epoch_loop._should_accumulate()

tests/tests_pytorch/loops/test_training_loop.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
# limitations under the License.
1414
import itertools
1515
import logging
16+
import warnings
1617
from unittest.mock import Mock
1718

1819
import pytest
1920
import torch
2021
from torch.utils.data import DataLoader
2122

23+
from lightning.fabric.utilities.warnings import PossibleUserWarning
2224
from lightning.pytorch import Trainer, seed_everything
2325
from lightning.pytorch.demos.boring_classes import BoringModel
2426
from lightning.pytorch.loops import _FitLoop
@@ -277,3 +279,29 @@ def __iter__(self):
277279

278280
# assert progress bar callback uses correct total steps
279281
assert pbar.train_progress_bar.total == max_steps
282+
283+
284+
@pytest.mark.parametrize("warn", [True, False])
285+
def test_eval_mode_warning(tmp_path, warn):
286+
"""Test that a warning is raised if any module is in eval mode at the start of training."""
287+
model = BoringModel()
288+
if warn:
289+
model.some_eval_module = torch.nn.Linear(32, 16)
290+
model.some_eval_module.eval()
291+
292+
trainer = Trainer(
293+
default_root_dir=tmp_path,
294+
max_epochs=1,
295+
)
296+
297+
if warn:
298+
with pytest.warns(PossibleUserWarning):
299+
trainer.fit(model)
300+
else:
301+
with warnings.catch_warnings(record=True) as warning_list:
302+
warnings.simplefilter("always")
303+
trainer.fit(model)
304+
eval_warnings = [
305+
w for w in warning_list if issubclass(w.category, PossibleUserWarning) and "eval mode" in str(w.message)
306+
]
307+
assert len(eval_warnings) == 0, "Expected no eval mode warnings"

0 commit comments

Comments
 (0)