27 changes: 6 additions & 21 deletions docs/source-pytorch/common/trainer.rst
@@ -1014,43 +1014,28 @@ The pseudocode applies also to the ``val_dataloader``.

.. _replace-sampler-ddp:

replace_sampler_ddp
^^^^^^^^^^^^^^^^^^^

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/replace_sampler_ddp.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/replace_sampler_ddp.mp4"></video>

|
use_distributed_sampler
^^^^^^^^^^^^^^^^^^^^^^^

Enables auto adding of :class:`~torch.utils.data.distributed.DistributedSampler`. In PyTorch, you must use it in
distributed settings such as TPUs or multi-node. The sampler makes sure each GPU sees the appropriate part of your data.
By default it will add ``shuffle=True`` for train sampler and ``shuffle=False`` for val/test sampler.
If you already use a custom sampler, Lightning will wrap it in a way that it samples from your sampler in a distributed manner.
If you want to customize it, you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.
If ``replace_sampler_ddp=True`` and a distributed sampler was already added,
Lightning will not replace the existing one.
See :paramref:`pytorch_lightning.trainer.Trainer.params.use_distributed_sampler`.

.. testcode::

# default used by the Trainer
trainer = Trainer(replace_sampler_ddp=True)
trainer = Trainer(use_distributed_sampler=True)

By setting to False, you have to add your own distributed sampler:

.. code-block:: python

# in your LightningModule or LightningDataModule
def train_dataloader(self):
dataset = ...
# default used by the Trainer
sampler = torch.utils.data.distributed.DistributedSampler(dataset, shuffle=True)
sampler = torch.utils.data.DistributedSampler(dataset, shuffle=True)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
return dataloader

.. note:: For iterable datasets, we don't do this automatically.


strategy
^^^^^^^^
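Putting the two doc snippets above together, a minimal end-to-end sketch of the opt-out path, assuming a toy TensorDataset and a LightningDataModule named ShardedData (neither appears in this PR):

    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
    from lightning.pytorch import LightningDataModule, Trainer


    class ShardedData(LightningDataModule):
        def train_dataloader(self):
            dataset = TensorDataset(torch.randn(64, 32))
            # shuffle through the sampler; with the flag off, Lightning leaves it untouched
            sampler = DistributedSampler(dataset, shuffle=True)
            return DataLoader(dataset, batch_size=32, sampler=sampler)


    # disable the automatic sampler injection described above
    trainer = Trainer(use_distributed_sampler=False)
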
2 changes: 1 addition & 1 deletion docs/source-pytorch/fabric/api/fabric_methods.rst
@@ -44,7 +44,7 @@ data tensors to the correct device automatically.
train_data, test_data = fabric.setup_dataloaders(train_data, test_data, move_to_device=False)

# If you don't want Fabric to replace the sampler in the context of distributed training
train_data, test_data = fabric.setup_dataloaders(train_data, test_data, replace_sampler=False)
train_data, test_data = fabric.setup_dataloaders(train_data, test_data, use_distributed_sampler=False)


backward
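A self-contained sketch of the renamed argument above, assuming a two-process CPU DDP run and a SequentialSampler standing in for a user-defined sampler:

    import torch
    from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp")
    fabric.launch()

    dataset = TensorDataset(torch.randn(64, 32))
    sampler = SequentialSampler(dataset)  # stands in for your own sampler
    dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)
    # keep the custom sampler instead of letting Fabric wrap or replace it
    dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=False)
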
1 change: 1 addition & 0 deletions src/lightning/fabric/CHANGELOG.md
@@ -42,6 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The selection `Fabric(strategy="ddp_spawn", ...)` no longer falls back to "ddp" when a cluster environment gets detected ([#16780](https://github.com/Lightning-AI/lightning/pull/16780))

- Renamed `setup_dataloaders(replace_sampler=...)` to `setup_dataloaders(use_distributed_sampler=...)` ([#16829](https://github.com/Lightning-AI/lightning/pull/16829))

### Deprecated

20 changes: 12 additions & 8 deletions src/lightning/fabric/fabric.py
@@ -273,15 +273,16 @@ def setup_optimizers(self, *optimizers: Optimizer) -> Union[_FabricOptimizer, Tu
return optimizers[0] if len(optimizers) == 1 else tuple(optimizers)

def setup_dataloaders(
self, *dataloaders: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
self, *dataloaders: DataLoader, use_distributed_sampler: bool = True, move_to_device: bool = True
) -> Union[DataLoader, List[DataLoader]]:
"""Set up one or multiple dataloaders for accelerated training. If you need different settings for each
dataloader, call this method individually for each one.

Args:
*dataloaders: A single dataloader or a sequence of dataloaders.
replace_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the dataloader(s)
for distributed training. If you have a custom sampler defined, set this to this argument to ``False``.
use_distributed_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the
dataloader(s) for distributed training. If you have a custom sampler defined, set this argument
to ``False``.
move_to_device: If set ``True`` (default), moves the data returned by the dataloader(s) automatically to
the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually on the
returned data.
@@ -291,21 +292,24 @@
"""
self._validate_setup_dataloaders(dataloaders)
dataloaders = [
self._setup_dataloader(dataloader, replace_sampler=replace_sampler, move_to_device=move_to_device)
self._setup_dataloader(
dataloader, use_distributed_sampler=use_distributed_sampler, move_to_device=move_to_device
)
for dataloader in dataloaders
]
dataloaders = dataloaders[0] if len(dataloaders) == 1 else dataloaders
return dataloaders # type: ignore[return-value]

def _setup_dataloader(
self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True
self, dataloader: DataLoader, use_distributed_sampler: bool = True, move_to_device: bool = True
) -> DataLoader:
"""Set up a single dataloader for accelerated training.

Args:
dataloader: The dataloader to accelerate.
replace_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the dataloader
for distributed training. If you have a custom sampler defined, set this to this argument to ``False``.
use_distributed_sampler: If set ``True`` (default), automatically wraps or replaces the sampler on the
dataloader for distributed training. If you have a custom sampler defined, set this argument to
``False``.
move_to_device: If set ``True`` (default), moves the data returned by the dataloader automatically to
the correct device. Set this to ``False`` and alternatively use :meth:`to_device` manually on the
returned data.
@@ -314,7 +318,7 @@ def _setup_dataloader(
The wrapped dataloader.
"""
sampler = dataloader.sampler
if replace_sampler and self._requires_distributed_sampler(dataloader):
if use_distributed_sampler and self._requires_distributed_sampler(dataloader):
sampler = self._get_distributed_sampler(dataloader, **self._strategy.distributed_sampler_kwargs)

# the dataloader needs to be re-instantiated because we want to update the input arguments (e.g., sampler)
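The docstring notes that dataloaders needing different settings should be set up individually; a rough sketch of that pattern, assuming a two-process CPU DDP run and toy datasets:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from lightning.fabric import Fabric

    fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp")
    fabric.launch()

    train_dl = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    predict_dl = DataLoader(TensorDataset(torch.randn(16, 32)), batch_size=8)

    # one call per dataloader when the settings differ
    train_dl = fabric.setup_dataloaders(train_dl)  # default: sampler is made distributed
    predict_dl = fabric.setup_dataloaders(predict_dl, use_distributed_sampler=False)
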
9 changes: 4 additions & 5 deletions src/lightning/fabric/utilities/distributed.py
@@ -257,7 +257,6 @@ def _get_default_process_group_backend_for_device(device: torch.device) -> str:
return "nccl" if device.type == "cuda" else "gloo"


# TODO(fabric): The error messages refer to 'replace_sampler_ddp' in PL but Fabric has it named 'replace_sampler'
class _DatasetSamplerWrapper(Dataset):
"""Dataset to create indexes from `Sampler` or `Iterable`"""

@@ -266,19 +265,19 @@ def __init__(self, sampler: Union[Sampler, Iterable]) -> None:
raise TypeError(
"You seem to have configured a sampler in your DataLoader which"
" does not provide `__len__` method. The sampler was about to be"
" replaced by `DistributedSamplerWrapper` since `replace_sampler_ddp`"
" replaced by `DistributedSamplerWrapper` since `use_distributed_sampler`"
" is True and you are using distributed training. Either provide `__len__`"
" method in your sampler, remove it from DataLoader or set `replace_sampler_ddp=False`"
" method in your sampler, remove it from DataLoader or set `use_distributed_sampler=False`"
" if you want to handle distributed sampling yourself."
)
if len(sampler) == float("inf"):
raise TypeError(
"You seem to have configured a sampler in your DataLoader which"
" does not provide finite `__len__` method. The sampler was about to be"
" replaced by `DistributedSamplerWrapper` since `replace_sampler_ddp`"
" replaced by `DistributedSamplerWrapper` since `use_distributed_sampler`"
" is True and you are using distributed training. Either provide `__len__`"
" method in your sampler which returns a finite number, remove it from DataLoader"
" or set `replace_sampler_ddp=False` if you want to handle distributed sampling yourself."
" or set `use_distributed_sampler=False` if you want to handle distributed sampling yourself."
)
self._sampler = sampler
# defer materializing an iterator until it is necessary
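Both error messages above reduce to one requirement: a custom sampler must expose a finite ``__len__`` before it can be wrapped for distributed use. A toy sampler meeting that requirement (illustrative only, not part of this change):

    from torch.utils.data import Sampler


    class EveryOtherSampler(Sampler):
        """Yields every other index; the finite __len__ is what allows the
        sampler to be wrapped when `use_distributed_sampler` is True."""

        def __init__(self, data_source):
            self.data_source = data_source

        def __iter__(self):
            return iter(range(0, len(self.data_source), 2))

        def __len__(self):
            return (len(self.data_source) + 1) // 2
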
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Renamed `CombinedLoader.loaders` to `CombinedLoader.iterables` ([#16743](https://github.com/Lightning-AI/lightning/pull/16743))


- Renamed `Trainer(replace_sampler_ddp=...)` to `Trainer(use_distributed_sampler=...)` ([#16829](https://github.com/Lightning-AI/lightning/pull/16829))


- Moved the `CombinedLoader` class from `lightning.pytorch.trainer.supporters` to `lightning.pytorch.combined_loader` ([#16819](https://github.com/Lightning-AI/lightning/pull/16819))


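On the user side, the two renames in these changelogs amount to swapping a keyword argument; a small sketch (the old spellings are shown only as comments):

    from lightning.pytorch import Trainer

    # previously: Trainer(replace_sampler_ddp=False)
    trainer = Trainer(use_distributed_sampler=False)

    # previously: fabric.setup_dataloaders(dataloader, replace_sampler=False)
    # now:        fabric.setup_dataloaders(dataloader, use_distributed_sampler=False)
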
@@ -88,7 +88,7 @@ def __init__(
precision: _PRECISION_INPUT = "32-true",
sync_batchnorm: bool = False,
benchmark: Optional[bool] = None,
replace_sampler_ddp: bool = True,
use_distributed_sampler: bool = True,
deterministic: Optional[Union[bool, _LITERAL_WARN]] = None,
) -> None:
"""The AcceleratorConnector parses several Trainer arguments and instantiates the Strategy including other
@@ -120,7 +120,7 @@
A. Class > str
B. Strategy > Accelerator/precision/plugins
"""
self.replace_sampler_ddp = replace_sampler_ddp
self.use_distributed_sampler = use_distributed_sampler
_set_torch_flags(deterministic=deterministic, benchmark=benchmark)

# 1. Parsing flags
@@ -229,7 +229,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:

def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
return (
self.trainer._accelerator_connector.replace_sampler_ddp
self.trainer._accelerator_connector.use_distributed_sampler
and self.trainer._accelerator_connector.is_distributed
and not isinstance(dataloader.sampler, DistributedSampler)
and not has_iterable_dataset(dataloader)
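The ``isinstance`` check above is what implements the documented behaviour that an already-attached distributed sampler is never replaced; a small illustration with a toy dataset (``num_replicas`` and ``rank`` are hard-coded here only so the snippet runs without a process group):

    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    dataset = TensorDataset(torch.randn(64, 32))
    # a distributed sampler attached by the user ...
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
    dataloader = DataLoader(dataset, batch_size=8, sampler=sampler)
    # ... fails the `not isinstance(..., DistributedSampler)` condition, so the
    # Trainer keeps it even though use_distributed_sampler defaults to True
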
15 changes: 9 additions & 6 deletions src/lightning/pytorch/trainer/trainer.py
@@ -123,7 +123,7 @@
benchmark: Optional[bool] = None,
deterministic: Optional[Union[bool, _LITERAL_WARN]] = None,
reload_dataloaders_every_n_epochs: int = 0,
replace_sampler_ddp: bool = True,
use_distributed_sampler: bool = True,
detect_anomaly: bool = False,
plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
inference_mode: bool = True,
@@ -251,10 +251,13 @@
reload_dataloaders_every_n_epochs: Set to a non-negative integer to reload dataloaders every n epochs.
Default: ``0``.

replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this
will toggled automatically when DDP is used. By default it will add ``shuffle=True`` for
train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.
use_distributed_sampler: Whether to wrap the DataLoader's sampler with
:class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
we don't do this automatically.

strategy: Supports different training strategies with aliases
as well custom strategies.
@@ -293,7 +296,7 @@ def __init__(
num_nodes=num_nodes,
sync_batchnorm=sync_batchnorm,
benchmark=benchmark,
replace_sampler_ddp=replace_sampler_ddp,
use_distributed_sampler=use_distributed_sampler,
deterministic=deterministic,
precision=precision,
plugins=plugins,
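The new docstring keeps the caveat that iterable-style datasets are never touched; a minimal sketch of sharding such a dataset by global rank yourself (class name and sizes are made up, and per-worker splitting is omitted for brevity):

    import torch
    import torch.distributed as dist
    from torch.utils.data import DataLoader, IterableDataset


    class ShardedStream(IterableDataset):
        """Each process iterates over its own slice of the stream."""

        def __init__(self, num_items: int = 1000):
            self.num_items = num_items

        def __iter__(self):
            rank, world_size = 0, 1
            if dist.is_available() and dist.is_initialized():
                rank, world_size = dist.get_rank(), dist.get_world_size()
            for i in range(rank, self.num_items, world_size):
                yield torch.tensor([float(i)])


    dataloader = DataLoader(ShardedStream(), batch_size=32)
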
10 changes: 5 additions & 5 deletions tests/tests_fabric/test_fabric.py
@@ -362,13 +362,13 @@ def test_setup_dataloaders_move_to_device(fabric_device_mock):


def test_setup_dataloaders_distributed_sampler_not_needed():
"""Test that replace_sampler option has no effect when no distributed sampler is needed."""
"""Test that `use_distributed_sampler` option has no effect when no distributed sampler is needed."""
custom_sampler = Mock(spec=Sampler)
dataloader = DataLoader(Mock(), sampler=custom_sampler)

# keep the custom sampler when not needed to replace
fabric = EmptyFabric()
fabric_dataloader = fabric.setup_dataloaders(dataloader, replace_sampler=True)
fabric_dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=True)
assert fabric_dataloader.sampler is custom_sampler


@@ -469,10 +469,10 @@ def test_setup_dataloaders_replace_custom_sampler(strategy):
fabric = EmptyFabric(accelerator="cpu", strategy=strategy, devices=2)
if hasattr(fabric.strategy, "distributed_sampler_kwargs"):
with pytest.raises(TypeError, match="You seem to have configured a sampler in your DataLoader"):
fabric.setup_dataloaders(dataloader, replace_sampler=True)
fabric.setup_dataloaders(dataloader, use_distributed_sampler=True)

# setting `replace_sampler=False` leaves the sampler untouched
fabric_dataloader = fabric.setup_dataloaders(dataloader, replace_sampler=False)
# setting `use_distributed_sampler=False` leaves the sampler untouched
fabric_dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=False)
assert fabric_dataloader.sampler is custom_sampler


2 changes: 1 addition & 1 deletion tests/tests_pytorch/benchmarks/test_basic_parity.py
@@ -161,7 +161,7 @@ def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
accelerator="gpu" if device_type == "cuda" else "cpu",
devices=1,
logger=False,
replace_sampler_ddp=False,
use_distributed_sampler=False,
benchmark=False,
)
trainer.fit(model)
@@ -65,7 +65,7 @@ def test_sync_batchnorm_parity(tmpdir):
max_steps=3,
sync_batchnorm=True,
num_sanity_val_steps=0,
replace_sampler_ddp=False,
use_distributed_sampler=False,
deterministic=True,
benchmark=False,
enable_progress_bar=False,
@@ -211,7 +211,7 @@ def __init__(self, num_feat, dataset, **kwargs):


def test_update_dataloader_with_multiprocessing_context():
"""This test verifies that replace_sampler conserves multiprocessing context."""
"""This test verifies that `use_distributed_sampler` conserves multiprocessing context."""
train = RandomDataset(32, 64)
context = "spawn"
train = DataLoader(train, batch_size=32, num_workers=2, multiprocessing_context=context, shuffle=True)
2 changes: 1 addition & 1 deletion tests/tests_pytorch/trainer/test_dataloaders.py
@@ -829,7 +829,7 @@ def test_dataloader_distributed_sampler_already_attached(tmpdir):
default_root_dir=tmpdir,
max_steps=100,
callbacks=[DistribSamplerCallback(expected_seeds=(11, 123, 0))],
replace_sampler_ddp=True,
use_distributed_sampler=True,
)
trainer.fit(model)
assert trainer.state.finished, "DDP Training failed"