
Commit 654c401

Fix an issue to avoid the impact of sanity check on reload_dataloaders_every_n_epochs for validation (#13964)
1 parent: 9cac16e

File tree: 7 files changed, +83 −85 lines

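For orientation before the per-file diffs: `reload_dataloaders_every_n_epochs` asks the `Trainer` to rebuild dataloaders on a fixed epoch cadence, while the sanity check runs a short validation pass before epoch 0. Prior to this commit, that sanity-check pass could count as the first validation reload and shift the whole schedule. A minimal sketch of an affected configuration; the flag values are illustrative and `MyLightningModule` is a hypothetical placeholder:

```python
import pytorch_lightning as pl

# Illustrative configuration; MyLightningModule is a hypothetical module that
# defines train_dataloader/val_dataloader and validation_step.
trainer = pl.Trainer(
    max_epochs=10,
    num_sanity_val_steps=2,               # short validation pass before epoch 0
    reload_dataloaders_every_n_epochs=2,  # the cadence this commit protects
    check_val_every_n_epoch=1,
)
# trainer.fit(MyLightningModule())
```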

src/pytorch_lightning/CHANGELOG.md

3 additions, 0 deletions

```diff
@@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed not preserving set attributes on `DataLoader` and `BatchSampler` when instantiated inside `*_dataloader` hooks ([#14212](https://github.com/Lightning-AI/lightning/pull/14212))


+- Fixed an issue to avoid the impact of sanity check on `reload_dataloaders_every_n_epochs` for validation ([#13964](https://github.com/Lightning-AI/lightning/pull/13964))
+
+
 ## [1.7.1] - 2022-08-09

 ### Fixed
```

src/pytorch_lightning/loops/dataloader/prediction_loop.py

2 additions, 1 deletion

```diff
@@ -60,7 +60,8 @@ def max_batches(self) -> List[int]:
     @property
     def dataloaders(self) -> Sequence[DataLoader]:
         """Returns all prediction dataloaders."""
-        return self.trainer.predict_dataloaders
+        dataloaders = self.trainer.predict_dataloaders
+        return [] if dataloaders is None else dataloaders

     @property
     def skip(self) -> bool:
```
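The property now normalizes an unset `predict_dataloaders` to an empty sequence instead of leaking `None` to callers. A standalone restatement of the guard; `_as_sequence` is a hypothetical helper, not Lightning API:

```python
from typing import Optional, Sequence

def _as_sequence(dataloaders: Optional[Sequence]) -> Sequence:
    # Mirrors the guard above: iteration stays safe even before any
    # predict dataloaders have been attached to the trainer.
    return [] if dataloaders is None else dataloaders

assert _as_sequence(None) == []
assert _as_sequence(["dl_a", "dl_b"]) == ["dl_a", "dl_b"]
```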

src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py

2 additions, 1 deletion

```diff
@@ -162,8 +162,9 @@ def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]:
         """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our
         :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""
         # the batch_sampler may not be defined in case of CombinedDataLoaders
+        assert self.trainer.predict_dataloaders
         batch_sampler = getattr(
-            self.trainer.predict_dataloaders[dataloader_idx],  # type: ignore[has-type]
+            self.trainer.predict_dataloaders[dataloader_idx],
             "batch_sampler",
             None,
         )
```
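Beyond guarding against an unset list, the added `assert` narrows the attribute's `Optional` type for static checkers, which is why the `# type: ignore[has-type]` on the indexing line can be dropped. A self-contained illustration of the narrowing pattern with hypothetical names:

```python
from typing import List, Optional

def get_batch_sampler(loaders: Optional[List[object]], idx: int) -> object:
    # After the assert, mypy treats `loaders` as List[object] rather than
    # Optional[List[object]], so the indexing below type-checks cleanly.
    assert loaders
    return getattr(loaders[idx], "batch_sampler", None)
```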

src/pytorch_lightning/loops/fit_loop.py

2 additions, 1 deletion

```diff
@@ -210,7 +210,8 @@ def on_run_start(self) -> None:  # type: ignore[override]

         self.trainer.reset_train_dataloader(self.trainer.lightning_module)
         # reload the evaluation dataloaders too for proper display in the progress bar
-        self.epoch_loop.val_loop._reload_evaluation_dataloaders()
+        if self.epoch_loop._should_check_val_epoch():
+            self.epoch_loop.val_loop._reload_evaluation_dataloaders()

         data_fetcher_cls = _select_data_fetcher(self.trainer)
         self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)
```
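`on_run_start` now reloads the evaluation dataloaders only when the first epoch will actually run validation. A simplified restatement of the epoch gate, assuming validation runs when `(epoch + 1)` is a multiple of `check_val_every_n_epoch` (consistent with the expectations in the test diff below):

```python
def should_check_val_epoch(current_epoch: int, check_val_every_n_epoch: int) -> bool:
    # Simplified stand-in for the epoch loop's _should_check_val_epoch().
    return (current_epoch + 1) % check_val_every_n_epoch == 0

# With check_val_every_n_epoch=3, validation runs on epochs 2, 5, 8 of a 10-epoch fit.
assert [e for e in range(10) if should_check_val_epoch(e, 3)] == [2, 5, 8]
```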

src/pytorch_lightning/trainer/connectors/data_connector.py

8 additions, 2 deletions

```diff
@@ -61,13 +61,19 @@ def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_
     def _should_reload_train_dl(self) -> bool:
         """Check if train dataloader should be reloaded."""
         n_epochs = self.trainer.reload_dataloaders_every_n_epochs
-        return n_epochs and (self.trainer.current_epoch - self.trainer._last_train_dl_reload_epoch >= n_epochs)
+        return n_epochs and (
+            self.trainer._last_train_dl_reload_epoch is None
+            or self.trainer.current_epoch - self.trainer._last_train_dl_reload_epoch >= n_epochs
+        )

     @property
     def _should_reload_val_dl(self) -> bool:
         """Check if validation dataloader should be reloaded."""
         n_epochs = self.trainer.reload_dataloaders_every_n_epochs
-        return n_epochs and (self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs)
+        return n_epochs and (
+            self.trainer._last_val_dl_reload_epoch is None
+            or self.trainer.current_epoch - self.trainer._last_val_dl_reload_epoch >= n_epochs
+        )

     def on_trainer_init(
         self,
```
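Both predicates replace the former `float("-inf")` sentinel with `None`, read as "never reloaded during this fit". A standalone restatement with the boundary cases spelled out; `should_reload` is a hypothetical helper:

```python
from typing import Optional

def should_reload(current_epoch: int, last_reload_epoch: Optional[int], n_epochs: int) -> bool:
    # None means no reload has happened yet, which always triggers one;
    # afterwards a reload is due once n_epochs have elapsed.
    return bool(n_epochs) and (
        last_reload_epoch is None or current_epoch - last_reload_epoch >= n_epochs
    )

assert should_reload(0, None, 2)    # never reloaded: reload immediately
assert not should_reload(1, 0, 2)   # only one epoch since the last reload
assert should_reload(2, 0, 2)       # interval reached
assert not should_reload(5, 4, 0)   # feature disabled when n_epochs == 0
```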

src/pytorch_lightning/trainer/trainer.py

12 additions, 9 deletions

```diff
@@ -626,12 +626,12 @@ def _setup_on_init(self) -> None:
         self.num_sanity_val_batches = []
         self.num_test_batches = []
         self.num_val_batches = []
+        self.num_predict_batches = []
         self.test_dataloaders = None
         self.val_dataloaders = None
-        self._last_train_dl_reload_epoch = float("-inf")
-        self._last_val_dl_reload_epoch = float("-inf")
-
-        self.num_predict_batches = []
+        self.predict_dataloaders = None
+        self._last_train_dl_reload_epoch = None
+        self._last_val_dl_reload_epoch: Optional[int] = None

     def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
         r"""
@@ -711,8 +711,6 @@ def _fit_impl(
         self.state.fn = TrainerFn.FITTING
         self.state.status = TrainerStatus.RUNNING
         self.training = True
-        self._last_train_dl_reload_epoch = float("-inf")
-        self._last_val_dl_reload_epoch = float("-inf")

         # if a datamodule comes in as the second arg, then fix it for the user
         if isinstance(train_dataloaders, LightningDataModule):
@@ -1923,13 +1921,18 @@ def reset_val_dataloader(self, model: Optional["pl.LightningModule"] = None) ->
         has_step = is_overridden("validation_step", pl_module)
         enable_validation = self.limit_val_batches > 0
         if source.is_defined() and has_step and enable_validation:
+            # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
+            # it should not reload again if it has already reloaded during sanity_check
+            if self.state.fn == TrainerFn.FITTING and (
+                (self.sanity_checking and self.fit_loop.epoch_loop._should_check_val_epoch())
+                or not self.sanity_checking
+            ):
+                self._last_val_dl_reload_epoch = self.current_epoch
+
             self.num_val_batches, self.val_dataloaders = self._data_connector._reset_eval_dataloader(
                 RunningStage.VALIDATING, model=pl_module
             )

-            # store epoch of dataloader reset for reload_dataloaders_every_n_epochs
-            self._last_val_dl_reload_epoch = self.current_epoch
-
     def reset_test_dataloader(self, model: Optional["pl.LightningModule"] = None) -> None:
         """Resets the test dataloader and determines the number of batches.
```

tests/tests_pytorch/trainer/test_dataloaders.py

54 additions, 71 deletions

```diff
@@ -946,104 +946,87 @@ def test_dataloaders_load_only_once_no_sanity_check(tmpdir):
     assert tracker.mock_calls == expected_sequence


-@pytest.mark.parametrize("n", [1, 2])
-def test_dataloaders_load_every_n_epochs(tmpdir, n):
-    train_reload_epochs, val_reload_epochs = [], []
-
-    class TestModel(BoringModel):
-        def train_dataloader(self):
-            train_reload_epochs.append(self.current_epoch)
-            return super().train_dataloader()
-
-        def val_dataloader(self):
-            val_reload_epochs.append(self.current_epoch)
-            return super().val_dataloader()
-
-    model = TestModel()
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        limit_train_batches=0.3,
-        limit_val_batches=0.3,
-        reload_dataloaders_every_n_epochs=n,
-        max_epochs=5,
-    )
-
-    tracker = Mock()
-    model.train_dataloader = Mock(wraps=model.train_dataloader)
-    model.val_dataloader = Mock(wraps=model.val_dataloader)
-    model.test_dataloader = Mock(wraps=model.test_dataloader)
-
-    tracker.attach_mock(model.train_dataloader, "train_dataloader")
-    tracker.attach_mock(model.val_dataloader, "val_dataloader")
-    tracker.attach_mock(model.test_dataloader, "test_dataloader")
-
-    trainer.fit(model)
-    trainer.test(model)
-
-    # Verify the sequence
-    expected_sequence = [call.val_dataloader(), call.train_dataloader()]  # Sanity check first
-    if n == 1:
-        expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 4
-    elif n == 2:
-        expected_sequence += [call.train_dataloader(), call.val_dataloader()] * 2
-    expected_sequence += [call.test_dataloader()]
-
-    assert tracker.mock_calls == expected_sequence
-
-    # Verify epoch of reloads
-    if n == 1:
-        assert train_reload_epochs == [0, 1, 2, 3, 4]
-        assert val_reload_epochs == [0, 1, 2, 3, 4]
-    elif n == 2:
-        assert train_reload_epochs == [0, 2, 4]
-        assert val_reload_epochs == [0, 2, 4]
-
-
 @pytest.mark.parametrize(
-    "n, train_reload_epochs_expect, val_reload_epochs_expect",
+    (
+        "num_sanity_val_steps, check_val_every_n_epoch, reload_dataloaders_every_n_epochs,"
+        " train_reload_epochs_expect,val_reload_epochs_expect,val_step_epochs_expect"
+    ),
     [
-        # Sanity check at epoch 0 creates a validation dataloader, but validation is
-        # checked (and in this case reloaded) every n epochs starting from epoch n-1
-        (3, [0, 2, 4, 6, 8], [0, 2, 5, 8]),
-        (5, [0, 2, 4, 6, 8], [0, 4, 9]),
+        # general case where sanity check reloads the dataloaders for validation on current_epoch=0
+        (0, 1, 1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
+        (1, 1, 1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
+        # case where check_val_every_n_epoch < reload_dataloaders_every_n_epochs so expected val_reload_epoch
+        # and val_step_epoch will be different
+        (0, 1, 2, [0, 2, 4, 6, 8], [0, 2, 4, 6, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
+        (1, 1, 2, [0, 2, 4, 6, 8], [2, 4, 6, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
+        (0, 3, 4, [0, 4, 8], [2, 8], [2, 5, 8]),
+        (1, 3, 4, [0, 4, 8], [2, 8], [2, 5, 8]),
+        # case where check_val_every_n_epoch > reload_dataloaders_every_n_epochs so expected val_reload_epoch
+        # and val_step_epoch will be same
+        (0, 2, 1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 3, 5, 7, 9], [1, 3, 5, 7, 9]),
+        (1, 2, 1, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 3, 5, 7, 9], [1, 3, 5, 7, 9]),
+        (0, 3, 2, [0, 2, 4, 6, 8], [2, 5, 8], [2, 5, 8]),
+        (1, 3, 2, [0, 2, 4, 6, 8], [2, 5, 8], [2, 5, 8]),
+        (0, 5, 2, [0, 2, 4, 6, 8], [4, 9], [4, 9]),
+        (1, 5, 2, [0, 2, 4, 6, 8], [4, 9], [4, 9]),
+        # case where check_val_every_n_epoch = reload_dataloaders_every_n_epochs so expected val_reload_epoch
+        # and val_step_epoch will be same
+        (0, 2, 2, [0, 2, 4, 6, 8], [1, 3, 5, 7, 9], [1, 3, 5, 7, 9]),
+        (1, 2, 2, [0, 2, 4, 6, 8], [1, 3, 5, 7, 9], [1, 3, 5, 7, 9]),
     ],
 )
 def test_dataloaders_load_every_n_epochs_infrequent_val(
-    tmpdir, n, train_reload_epochs_expect, val_reload_epochs_expect
+    tmpdir,
+    num_sanity_val_steps,
+    check_val_every_n_epoch,
+    reload_dataloaders_every_n_epochs,
+    train_reload_epochs_expect,
+    val_reload_epochs_expect,
+    val_step_epochs_expect,
 ):
     """Test dataloader reload behavior when infrequently checking validation set (via check_val_every_n_epoch)"""
-    train_reload_epochs, val_reload_epochs = [], []
+    sanity_val_check_epochs, train_reload_epochs, val_reload_epochs = [], [], []
+    sanity_val_step_epochs, val_step_epochs = [], []

     class TestModel(BoringModel):
         def train_dataloader(self):
             train_reload_epochs.append(self.current_epoch)
             return super().train_dataloader()

         def val_dataloader(self):
-            val_reload_epochs.append(self.current_epoch)
+            if self.trainer.sanity_checking:
+                sanity_val_check_epochs.append(self.current_epoch)
+            else:
+                val_reload_epochs.append(self.current_epoch)
             return super().val_dataloader()

+        def validation_step(self, *args, **kwargs):
+            if self.trainer.sanity_checking:
+                sanity_val_step_epochs.append(self.current_epoch)
+            else:
+                val_step_epochs.append(self.current_epoch)
+
+            return super().validation_step(*args, **kwargs)
+
     model = TestModel()

     trainer = Trainer(
         default_root_dir=tmpdir,
-        limit_train_batches=0.3,
-        limit_val_batches=0.3,
-        check_val_every_n_epoch=n,
-        reload_dataloaders_every_n_epochs=2,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        check_val_every_n_epoch=check_val_every_n_epoch,
+        reload_dataloaders_every_n_epochs=reload_dataloaders_every_n_epochs,
         max_epochs=10,
+        num_sanity_val_steps=num_sanity_val_steps,
     )
-    model.test_dataloader = Mock(wraps=model.test_dataloader)
-
     trainer.fit(model)
-    trainer.test(model)

     # Verify epoch of reloads
+    sanity_val_check_epochs_expect = [0] if num_sanity_val_steps else []
+    assert sanity_val_check_epochs == sanity_val_step_epochs == sanity_val_check_epochs_expect
     assert train_reload_epochs == train_reload_epochs_expect
     assert val_reload_epochs == val_reload_epochs_expect
-
-    model.test_dataloader.assert_called_once()
+    assert val_step_epochs == val_step_epochs_expect


 def test_dataloaders_load_every_n_epochs_frequent_val(tmpdir):
```
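One way to read a row of the new parametrization is to compose the two simplified predicates sketched earlier. A worked check of the `(0, 3, 2, ...)` row, under the same assumptions:

```python
# check_val_every_n_epoch=3, reload_dataloaders_every_n_epochs=2, max_epochs=10,
# no sanity check. Epochs that run validation:
val_epochs = [e for e in range(10) if (e + 1) % 3 == 0]

# Of those, epochs that also reload the val dataloader (None = never reloaded yet):
reloads, last = [], None
for e in val_epochs:
    if last is None or e - last >= 2:
        reloads.append(e)
        last = e

assert val_epochs == [2, 5, 8]  # matches val_step_epochs_expect
assert reloads == [2, 5, 8]     # matches val_reload_epochs_expect
```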
