Skip to content

Commit d77132b

Browse files
authored
Avoid expensive iter() call to dataloader in dataloader checks (#18415)
1 parent 722fdea commit d77132b

File tree

4 files changed

+27
-9
lines changed

4 files changed

+27
-9
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
244244
- Fixed setting the tracking uri in `MLFlowLogger` for logging artifacts to the MLFlow server ([#18395](https://github.com/Lightning-AI/lightning/pull/18395))
245245

246246

247+
- Fixed redundant `iter()` call to dataloader when checking dataloading configuration ([#18415](https://github.com/Lightning-AI/lightning/pull/18415))
248+
249+
247250
## [2.0.5] - 2023-07-07
248251

249252
### Fixed

src/lightning/pytorch/trainer/connectors/data_connector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,10 @@ def _check_dataloader_iterable(
393393
source: _DataLoaderSource,
394394
trainer_fn: TrainerFn,
395395
) -> None:
396+
if isinstance(dataloader, DataLoader):
397+
# Fast path: `torch.utils.data.DataLoader` is always iterable, calling iter() would be expensive
398+
return
399+
396400
try:
397401
iter(dataloader) # type: ignore[call-overload]
398402
except TypeError:

tests/tests_pytorch/loops/test_loops.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -818,17 +818,13 @@ def _get_iterator(self):
818818
expected = [trainer.current_epoch, trainer.current_epoch] # once epoch end, once on teardown
819819
elif should_fail:
820820
expected = [
821-
# iterable check
822-
0,
823821
# epoch ends
824822
1,
825823
# teardown
826824
1,
827825
]
828826
else:
829827
expected = [
830-
# iterable check
831-
0,
832828
# epoch ends
833829
1,
834830
2,
@@ -843,8 +839,6 @@ def _get_iterator(self):
843839
expected = [
844840
# sanity check
845841
0,
846-
# iterable check
847-
0,
848842
# epoch ends
849843
0,
850844
1,
@@ -853,8 +847,6 @@ def _get_iterator(self):
853847
expected = [
854848
# sanity check
855849
0,
856-
# iterable check
857-
0,
858850
# epoch ends
859851
0,
860852
1,

tests/tests_pytorch/trainer/connectors/test_data_connector.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,12 @@
2626
from lightning.fabric.utilities.warnings import PossibleUserWarning
2727
from lightning.pytorch import Trainer
2828
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
29-
from lightning.pytorch.trainer.connectors.data_connector import _DataHookSelector, _DataLoaderSource, warning_cache
29+
from lightning.pytorch.trainer.connectors.data_connector import (
30+
_check_dataloader_iterable,
31+
_DataHookSelector,
32+
_DataLoaderSource,
33+
warning_cache,
34+
)
3035
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
3136
from lightning.pytorch.utilities.combined_loader import CombinedLoader
3237
from lightning.pytorch.utilities.data import _update_dataloader
@@ -643,3 +648,17 @@ def test_non_iterables_raise(tmp_path, trainer_fn_name, dataloader_name, stage,
643648
setattr(model, dl_method, lambda: dataloader)
644649
with pytest.raises(TypeError, match=f"invalid dataloader was returned from `BoringModel.{dl_method}"):
645650
trainer_fn(model)
651+
652+
653+
def test_iterable_check_on_known_iterators():
654+
"""Test that we only call the `iter()` on the dataloader object if it isn't a known type."""
655+
iterable = Mock()
656+
iterable.__iter__ = Mock(return_value=iter(range(3)))
657+
_check_dataloader_iterable(iterable, Mock(), Mock())
658+
iterable.__iter__.assert_called_once()
659+
660+
# If it's a datalaoder, we don't call the expensive `__iter__` method
661+
dataloader = Mock(spec=DataLoader)
662+
dataloader.__iter__ = Mock()
663+
_check_dataloader_iterable(dataloader, Mock(), Mock())
664+
dataloader.__iter__.assert_not_called()

0 commit comments

Comments
 (0)