5 changes: 5 additions & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for saving and loading DeepSpeed checkpoints through `Fabric.save/load()` ([#16452](https://github.com/Lightning-AI/lightning/pull/16452))


- Added support for automatically calling `set_epoch` on the `dataloader.batch_sampler.sampler` ([#16841](https://github.com/Lightning-AI/lightning/pull/16841))


### Changed

- Checkpoint saving and loading redesign ([#16434](https://github.com/Lightning-AI/lightning/pull/16434))
@@ -59,6 +62,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed an issue where the wrapped dataloader's `iter()` would be called twice ([#16841](https://github.com/Lightning-AI/lightning/pull/16841))

- Fixed an issue causing a wrong environment plugin to be selected when `accelerator=tpu` and `devices > 1` ([#16806](https://github.com/Lightning-AI/lightning/pull/16806))
- Fixed parsing of defaults for `--accelerator` and `--precision` in Fabric CLI when `accelerator` and `precision` are set to non-default values in the code ([#16818](https://github.com/Lightning-AI/lightning/pull/16818))

22 changes: 22 additions & 0 deletions src/lightning/fabric/utilities/data.py
@@ -414,3 +414,25 @@ def _replace_value_in_saved_args(
return True, args, kwargs

return False, args, kwargs


def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None:
"""Calls the ``set_epoch`` method on either the sampler of the given dataloader.

Every PyTorch dataloader has either a sampler or a batch sampler. If the sampler is wrapped by a
:class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
"""
objects = set()
# check dataloader.sampler
if (sampler := getattr(dataloader, "sampler", None)) is not None:
objects.add(sampler)
# check dataloader.batch_sampler.sampler
if (batch_sampler := getattr(dataloader, "batch_sampler", None)) is not None and (
sampler := getattr(batch_sampler, "sampler", None)
) is not None:
objects.add(sampler)
for obj in objects:
set_epoch = getattr(obj, "set_epoch", None)
if callable(set_epoch):
set_epoch(epoch)
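
To make the new helper concrete, here is a minimal sketch (not part of the diff) that exercises `_set_sampler_epoch` on a dataloader whose `DistributedSampler` sits inside a `BatchSampler`; the explicit `num_replicas`/`rank` arguments are only there so no process group needs to be initialized, mirroring the tests further down.

```python
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler

from lightning.fabric.utilities.data import _set_sampler_epoch

dataset = list(range(8))
# Explicit num_replicas/rank so torch.distributed does not need to be initialized.
dist_sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
batch_sampler = BatchSampler(dist_sampler, batch_size=2, drop_last=False)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

# The helper inspects both dataloader.sampler (typically a default SequentialSampler
# here, which has no set_epoch) and dataloader.batch_sampler.sampler (the
# DistributedSampler), and calls set_epoch wherever it exists and is callable.
_set_sampler_epoch(dataloader, 3)
assert dist_sampler.epoch == 3  # the next iteration shuffles with the epoch-3 ordering
```

Collecting the candidates in a set means a sampler reachable through both attributes receives the call only once.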
20 changes: 9 additions & 11 deletions src/lightning/fabric/wrappers.py
@@ -26,6 +26,7 @@
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.strategies import Strategy
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
from lightning.fabric.utilities.types import Optimizable

@@ -168,20 +169,17 @@ def __len__(self) -> int:
return len(self._dataloader)

def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
if hasattr(self._dataloader.sampler, "set_epoch"):
# Without setting the epoch, the distributed sampler would return the same indices every time, even when
# shuffling is enabled. In PyTorch, the user would normally have to call `.set_epoch()` on the sampler.
# In Lite, we take care of this boilerplate code.
self._dataloader.sampler.set_epoch(self._num_iter_calls)
# Without setting the epoch, the distributed sampler would return the same indices every time, even when
# shuffling is enabled. In PyTorch, the user would normally have to call `.set_epoch()` on the sampler.
# In Fabric, we take care of this boilerplate code.
_set_sampler_epoch(self._dataloader, self._num_iter_calls)
self._num_iter_calls += 1

iterator = iter(self._dataloader)
if self._device is None:
yield from iterator
return

for item in iterator:
yield move_data_to_device(item, self._device)
yield from iter(self._dataloader)
else:
for item in self._dataloader:
yield move_data_to_device(item, self._device)


def _process_optimizer_zero_grad_kwargs(optimizer: Optimizer, kwargs: Dict[str, Any]) -> Dict[str, Any]:
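
For contrast, the boilerplate that `_FabricDataLoader.__iter__` now absorbs looks roughly like the plain-PyTorch loop below (a sketch, not taken from the PR); the wrapper performs the equivalent `_set_sampler_epoch` call itself, using its internal counter of `iter()` calls as the epoch number.

```python
from torch.utils.data import DataLoader, DistributedSampler

dataset = list(range(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
dataloader = DataLoader(dataset, sampler=sampler, batch_size=2)

for epoch in range(3):
    # Without this call, every epoch would replay the same epoch-0 shuffle order.
    sampler.set_epoch(epoch)
    for batch in dataloader:
        pass  # training step would go here
```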
3 changes: 2 additions & 1 deletion src/lightning/pytorch/loops/evaluation_loop.py
@@ -21,11 +21,12 @@
from torch import Tensor

import lightning.pytorch as pl
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.progress import BatchProgress
from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _set_sampler_epoch
from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
4 changes: 2 additions & 2 deletions src/lightning/pytorch/loops/fit_loop.py
@@ -15,12 +15,12 @@
from typing import Optional, Union

import lightning.pytorch as pl
from lightning.fabric.utilities.data import _auto_add_worker_init_fn
from lightning.fabric.utilities.data import _auto_add_worker_init_fn, _set_sampler_epoch
from lightning.pytorch.loops import _Loop
from lightning.pytorch.loops.fetchers import _DataFetcher
from lightning.pytorch.loops.progress import Progress
from lightning.pytorch.loops.training_epoch_loop import _TrainingEpochLoop
from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher, _set_sampler_epoch
from lightning.pytorch.loops.utilities import _is_max_limit_reached, _select_data_fetcher
from lightning.pytorch.trainer import call
from lightning.pytorch.trainer.connectors.data_connector import _DataLoaderSource
from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection
3 changes: 2 additions & 1 deletion src/lightning/pytorch/loops/prediction_loop.py
@@ -6,11 +6,12 @@

import lightning.pytorch as pl
from lightning.fabric.utilities import move_data_to_device
from lightning.fabric.utilities.data import _set_sampler_epoch
from lightning.pytorch.callbacks import BasePredictionWriter
from lightning.pytorch.loops.fetchers import _DataFetcher, _DataLoaderIterDataFetcher
from lightning.pytorch.loops.loop import _Loop
from lightning.pytorch.loops.progress import Progress
from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher, _set_sampler_epoch
from lightning.pytorch.loops.utilities import _no_grad_context, _select_data_fetcher
from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper
from lightning.pytorch.strategies import DDPSpawnStrategy
from lightning.pytorch.trainer import call
24 changes: 1 addition & 23 deletions src/lightning/pytorch/loops/utilities.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Callable, Generator, Iterable, Optional, Tuple
from typing import Any, Callable, Generator, Optional, Tuple

import torch
import torch.distributed as dist
@@ -123,28 +123,6 @@ def _reset_progress(loop: _Loop) -> None:
_reset_progress(v)


def _set_sampler_epoch(dataloader: Iterable, epoch: int) -> None:
"""Calls the ``set_epoch`` method on either the sampler of the given dataloader.

Every PyTorch dataloader has either a sampler or a batch sampler. If the sampler is wrapped by a
:class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.
"""
objects = set()
# check dataloader.sampler
if (sampler := getattr(dataloader, "sampler", None)) is not None:
objects.add(sampler)
# check dataloader.batch_sampler.sampler
if (batch_sampler := getattr(dataloader, "batch_sampler", None)) is not None and (
sampler := getattr(batch_sampler, "sampler", None)
) is not None:
objects.add(sampler)
for obj in objects:
set_epoch = getattr(obj, "set_epoch", None)
if callable(set_epoch):
set_epoch(epoch)


def _select_data_fetcher(trainer: "pl.Trainer") -> _DataFetcher:
lightning_module = trainer.lightning_module
if trainer.testing:
30 changes: 21 additions & 9 deletions tests/tests_fabric/test_wrappers.py
@@ -17,7 +17,7 @@
import pytest
import torch
from tests_fabric.helpers.runif import RunIf
from torch.utils.data import DistributedSampler
from torch.utils.data import BatchSampler, DistributedSampler
from torch.utils.data.dataloader import DataLoader

from lightning.fabric.fabric import Fabric
@@ -232,24 +232,36 @@ def test_fabric_dataloader_device_placement(src_device_str, dest_device_str):
assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))


def test_fabric_dataloader_distributed_sampler_set_epoch():
@pytest.mark.parametrize("use_batch_sampler", (False, True))
def test_fabric_dataloader_distributed_sampler_set_epoch(use_batch_sampler):
"""Test that the FabricDataLoader calls `set_epoch()` on the wrapped sampler if applicable."""
sampler = DistributedSampler(range(3), num_replicas=2, rank=0)
dataset = range(3)
sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
sampler.set_epoch = Mock()
dataloader = DataLoader(range(3), sampler=sampler)

if not use_batch_sampler:
dataloader = DataLoader(dataset, sampler=sampler)
else:
batch_sampler = BatchSampler(sampler, batch_size=1, drop_last=False)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

fabric_dataloader = _FabricDataLoader(dataloader)
iterator_epoch_0 = iter(fabric_dataloader)
dataloader.sampler.set_epoch.assert_not_called()
sampler.set_epoch.assert_not_called()

next(iterator_epoch_0)
# .set_epoch() gets called before the first sample gets fetched from the wrapped dataloader
assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
assert sampler.set_epoch.mock_calls == [call(0)]

next(iterator_epoch_0)
assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
assert sampler.set_epoch.mock_calls == [call(0)]

iterator_epoch_1 = iter(fabric_dataloader)
assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
assert sampler.set_epoch.mock_calls == [call(0)]

next(iterator_epoch_1)
# with every new iterator call, the epoch increases
assert dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)]
assert sampler.set_epoch.mock_calls == [call(0), call(1)]


def test_fabric_optimizer_wraps():
22 changes: 22 additions & 0 deletions tests/tests_fabric/utilities/test_data.py
@@ -1,4 +1,5 @@
import random
from unittest.mock import Mock

import numpy as np
import pytest
@@ -12,6 +13,7 @@
_get_dataloader_init_args_and_kwargs,
_replace_dunder_methods,
_replace_value_in_saved_args,
_set_sampler_epoch,
_update_dataloader,
_WrapAttrTag,
has_iterable_dataset,
@@ -525,3 +527,23 @@ def __init__(self, indices=None, **kwargs):
dataloader = ArrayAttributeDataloader(dataset)
dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, dataloader.sampler)
assert dl_kwargs["indices"] is dataloader.indices


def test_set_sampler_epoch():
# No samplers
dataloader = Mock()
dataloader.sampler = None
dataloader.batch_sampler = None
_set_sampler_epoch(dataloader, 55)

# set_epoch not callable
dataloader = Mock()
dataloader.sampler.set_epoch = None
dataloader.batch_sampler.set_epoch = None
_set_sampler_epoch(dataloader, 55)

# set_epoch callable
dataloader = Mock()
_set_sampler_epoch(dataloader, 55)
dataloader.sampler.set_epoch.assert_called_once_with(55)
dataloader.batch_sampler.sampler.set_epoch.assert_called_once_with(55)
8 changes: 4 additions & 4 deletions tests/tests_pytorch/loops/test_prediction_loop.py
@@ -51,8 +51,8 @@ def predict_step(self, batch, batch_idx):
assert trainer.predict_loop.predictions == []


@pytest.mark.parametrize("replace_sampler_ddp", (False, True))
def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, replace_sampler_ddp):
@pytest.mark.parametrize("use_distributed_sampler", (False, True))
def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, use_distributed_sampler):
"""Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction."""
trainer = Trainer(
default_root_dir=tmp_path,
@@ -63,14 +63,14 @@ def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, replace_sample
strategy="ddp",
devices=1,
accelerator="cpu",
replace_sampler_ddp=replace_sampler_ddp,
use_distributed_sampler=use_distributed_sampler,
)

class MyModel(BoringModel):
def predict_dataloader(self):
dataset = RandomDataset(32, 64)
sampler = None
if not replace_sampler_ddp:
if not use_distributed_sampler:
sampler = DistributedSampler(dataset)
return DataLoader(dataset, sampler=sampler)

36 changes: 0 additions & 36 deletions tests/tests_pytorch/loops/test_utilities.py

This file was deleted.