3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -62,6 +62,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Renamed `TQDMProgressBar.main_progress_bar` to `TQDMProgressBar.train_progress_bar` ([#16695](https://github.com/Lightning-AI/lightning/pull/16695))

+- Marked `lightning.pytorch.utilities.supporters.CombinedDataset` as protected ([#16714](https://github.com/Lightning-AI/lightning/pull/16714))

### Deprecated

@@ -223,6 +224,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Removed the unused `lightning.pytorch.utilities.metrics.metrics_to_scalars` function ([#16681](https://github.com/Lightning-AI/lightning/pull/16681))

+- Removed the unused `lightning.pytorch.utilities.supporters.{SharedCycleIteratorState,CombinedLoaderIterator}` classes ([#16714](https://github.com/Lightning-AI/lightning/pull/16714))

### Fixed

- Fixed an unintended limitation for calling `save_hyperparameters` on mixin classes that don't subclass `LightningModule`/`LightningDataModule` ([#16369](https://github.com/Lightning-AI/lightning/pull/16369))
12 changes: 2 additions & 10 deletions src/lightning/pytorch/loops/fetchers.py
@@ -17,7 +17,7 @@
from torch.utils.data.dataloader import DataLoader

from lightning.fabric.utilities.data import has_len
-from lightning.pytorch.trainer.supporters import CombinedLoader
+from lightning.pytorch.trainer.supporters import _shutdown_workers_and_reset_iterator, CombinedLoader
from lightning.pytorch.utilities.exceptions import MisconfigurationException


@@ -45,14 +45,6 @@ def dataloader(self) -> Iterable:
)
return self._dataloader

-    @property
-    def loader_iters(self) -> Any:
-        if self.dataloader_iter is None:
-            raise MisconfigurationException("The `dataloader_iter` isn't available outside the __iter__ context.")
-        if isinstance(self.dataloader, CombinedLoader):
-            return self.dataloader_iter.loader_iters
-        return self.dataloader_iter

def __iter__(self) -> "_DataFetcher":
self.reset()
self.dataloader_iter = iter(self.dataloader)
@@ -80,7 +72,7 @@ def teardown(self) -> None:
if isinstance(self._dataloader, CombinedLoader):
self._dataloader.reset()
if isinstance(self._dataloader, DataLoader):
-            CombinedLoader._shutdown_workers_and_reset_iterator(self._dataloader)
+            _shutdown_workers_and_reset_iterator(self._dataloader)
self.dataloader_iter = None


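Note: the fetchers.py hunks above swap the former `CombinedLoader._shutdown_workers_and_reset_iterator(...)` call for a module-level helper of the same name imported from `lightning.pytorch.trainer.supporters`. As a rough illustration only (not the implementation shipped by this PR), a standalone helper of that shape might simply stop any multiprocessing workers and clear the cached iterator:

```python
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter


def _shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
    """Sketch of a module-level teardown helper (illustrative, not the PR's code)."""
    iterator = getattr(dataloader, "_iterator", None)
    # If the DataLoader spawned worker processes, ask the iterator to stop them.
    if isinstance(iterator, _MultiProcessingDataLoaderIter):
        iterator._shutdown_workers()
    # Drop the cached iterator so the next iter(dataloader) starts from a clean state.
    dataloader._iterator = None
```

Moving such a helper out of `CombinedLoader` lets `_DataFetcher.teardown` call it on a plain `DataLoader` without routing through the combined-loader class, which matches the call-site change in the hunk.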
25 changes: 5 additions & 20 deletions src/lightning/pytorch/trainer/connectors/data_connector.py
@@ -28,7 +28,7 @@
from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSamplerWrapper
from lightning.pytorch.strategies import DDPSpawnStrategy
from lightning.pytorch.trainer.states import RunningStage, TrainerFn
-from lightning.pytorch.trainer.supporters import CombinedLoader, CycleIterator
+from lightning.pytorch.trainer.supporters import _LITERAL_SUPPORTED_MODES, CombinedLoader
from lightning.pytorch.utilities.data import _is_dataloader_shuffled, _update_dataloader, has_len_all_ranks
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
@@ -40,7 +40,7 @@


class DataConnector:
-    def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
+    def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: _LITERAL_SUPPORTED_MODES = "max_size_cycle"):
self.trainer = trainer
self.multiple_trainloader_mode = multiple_trainloader_mode
self._train_dataloader_source = _DataLoaderSource(None, "")
@@ -239,28 +239,17 @@ def _prepare_dataloader(
"""This function handles the following functionalities:

- Injecting a `DistributedDataSamplerWrapper` into the `DataLoader` if on a distributed environment
-        - Wrapping the datasets and samplers into fault-tolerant components
- Wrapping the dataloader based on strategy-specific logic
"""
if isinstance(dataloader, CombinedLoader):
-            # apply `_prepare_dataloader` on all the collection of loaders
-            dataloader.loaders = apply_to_collection(
-                dataloader.loaders, (DataLoader, CycleIterator), self._prepare_dataloader, shuffle, mode=mode
-            )
-            # the length need to recomputed across all dataloaders in case of special behavior.
-            dataloader._apply_cycle_iterator_length()
+            for i, dl in enumerate(dataloader._loaders_flattened):
+                dataloader._update_index(self._prepare_dataloader(dl, shuffle=shuffle, mode=mode), i)
return dataloader

# don't do anything if it's not a dataloader
-        if not isinstance(dataloader, (DataLoader, CycleIterator)):
+        if not isinstance(dataloader, DataLoader):
return dataloader

-        cycle_iterator: Optional[CycleIterator] = None
-
-        if isinstance(dataloader, CycleIterator):
-            cycle_iterator = dataloader
-            dataloader = dataloader.loader

if (
self._requires_distributed_sampler(dataloader) # sets the distributed sampler
or mode == RunningStage.PREDICTING # to track indices for the predictions
@@ -277,10 +266,6 @@

dataloader = self.trainer.strategy.process_dataloader(dataloader)

-        if cycle_iterator is not None:
-            cycle_iterator.loader = dataloader
-            return cycle_iterator

return dataloader

def _resolve_sampler(
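Note: in data_connector.py the `CycleIterator` plumbing is removed entirely, and `multiple_trainloader_mode` is now annotated with `_LITERAL_SUPPORTED_MODES` from the supporters module instead of a bare `str`. A minimal sketch of how such a `Literal` alias tightens the constructor signature follows; the alias name comes from the diff, while the concrete mode strings and the omission of the `trainer` argument are assumptions made for brevity:

```python
from typing import Literal

# Hypothetical stand-in for lightning.pytorch.trainer.supporters._LITERAL_SUPPORTED_MODES;
# the real alias may allow additional modes.
_LITERAL_SUPPORTED_MODES = Literal["min_size", "max_size_cycle"]


class DataConnectorSketch:
    """Illustrative only: mirrors the typed parameter, not the real DataConnector."""

    def __init__(self, multiple_trainloader_mode: _LITERAL_SUPPORTED_MODES = "max_size_cycle") -> None:
        # A static type checker now rejects unsupported mode strings at the call
        # site instead of leaving validation entirely to runtime checks.
        self.multiple_trainloader_mode = multiple_trainloader_mode


connector = DataConnectorSketch("min_size")       # accepted
# connector = DataConnectorSketch("round_robin")  # flagged by mypy/pyright
```

The `CombinedLoader` branch of `_prepare_dataloader` likewise changes from recursing with `apply_to_collection` over `loaders`/`CycleIterator` wrappers to iterating the flattened loaders and writing each prepared dataloader back by index, consistent with `CycleIterator` being dropped from the import.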