Commit 90c7c19

awaelchli authored and carmocca committed
Distributed sampling parity between Lite and PyTorch (#16101)
1 parent 9f05f49 commit 90c7c19

6 files changed: +81, -3 lines changed

src/lightning_fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Merged the implementation of `DDPSpawnStrategy` into `DDPStrategy` and removed `DDPSpawnStrategy` ([#14952](https://github.com/Lightning-AI/lightning/issues/14952))


+- The dataloader wrapper returned from `.setup_dataloaders()` now calls `.set_epoch()` on the distributed sampler if one is used ([#16101](https://github.com/Lightning-AI/lightning/issues/16101))
+
+
 ### Deprecated

 -
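For context, the changelog entry above removes a common piece of DDP boilerplate. The sketch below is modeled on the parity test added later in this commit: Fabric is configured for two processes but not launched, so the current process acts as rank 0, and the dataset is a toy tensor. It is an illustrative sketch, not part of the commit itself.

import torch
from torch.utils.data import DataLoader
from lightning_fabric.fabric import Fabric

dataset = torch.arange(10)
fabric = Fabric(accelerator="cpu", strategy="ddp", devices=2)
# The wrapper returned here advances the sampler epoch on every new iterator.
dataloader = fabric.setup_dataloaders(DataLoader(dataset, shuffle=True))

for epoch in range(2):
    # No manual `dataloader.sampler.set_epoch(epoch)` call is needed anymore.
    for batch in dataloader:
        pass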

src/lightning_fabric/fabric.py

Lines changed: 3 additions & 1 deletion
@@ -25,7 +25,7 @@
 from lightning_utilities.core.rank_zero import rank_zero_warn
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler
+from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler

 from lightning_fabric.plugins import Precision  # avoid circular imports: # isort: split
 from lightning_fabric.accelerators.accelerator import Accelerator
@@ -639,6 +639,8 @@ def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
     def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSampler:
         kwargs.setdefault("shuffle", isinstance(dataloader.sampler, RandomSampler))
         kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
+        if isinstance(dataloader.sampler, (RandomSampler, SequentialSampler)):
+            return DistributedSampler(dataloader.dataset, **kwargs)
         return DistributedSamplerWrapper(dataloader.sampler, **kwargs)

     def _prepare_run_method(self) -> None:
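As a rough standalone illustration of the branch added above: the default samplers are replaced outright by a plain DistributedSampler, while anything custom gets wrapped so that only index sharding is added on top of it. The helper name `pick_distributed_sampler` is made up for this sketch; the real logic lives in `Fabric._get_distributed_sampler` and additionally seeds the sampler from `PL_GLOBAL_SEED`.

import torch
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler
from lightning_fabric.utilities.distributed import DistributedSamplerWrapper


def pick_distributed_sampler(dataloader: DataLoader, **kwargs):
    # Shuffle only if the dataloader was created with shuffle=True (i.e. a RandomSampler).
    kwargs.setdefault("shuffle", isinstance(dataloader.sampler, RandomSampler))
    if isinstance(dataloader.sampler, (RandomSampler, SequentialSampler)):
        # Default samplers carry no custom logic: a plain DistributedSampler can take
        # over both (optional) shuffling and sharding.
        return DistributedSampler(dataloader.dataset, **kwargs)
    # A custom sampler keeps its own ordering; the wrapper only shards its indices.
    return DistributedSamplerWrapper(dataloader.sampler, **kwargs)


loader = DataLoader(torch.arange(10), shuffle=True)
sampler = pick_distributed_sampler(loader, num_replicas=2, rank=0)
print(type(sampler).__name__)  # DistributedSampler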

src/lightning_fabric/utilities/distributed.py

Lines changed: 5 additions & 0 deletions
@@ -294,6 +294,11 @@ class DistributedSamplerWrapper(DistributedSampler):

     Allows you to use any sampler in distributed mode. It will be automatically used by Lightning in distributed mode if
     sampler replacement is enabled.
+
+    Note:
+        The purpose of this wrapper is to take care of sharding the sampler indices. It is up to the underlying
+        sampler to handle randomness and shuffling. The ``shuffle`` and ``seed`` arguments on this wrapper won't
+        have any effect.
     """

     def __init__(self, sampler: Union[Sampler, Iterable], *args: Any, **kwargs: Any) -> None:
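A conceptual sketch of the division of labor described in the new note: the wrapped sampler decides the order, and the wrapper only splits that order across ranks. This is not the actual `DistributedSamplerWrapper` implementation; `shard_indices` is a made-up helper for illustration.

from typing import Iterator, List


def shard_indices(order: List[int], num_replicas: int, rank: int) -> Iterator[int]:
    # Keep the sampler's order untouched and hand every `num_replicas`-th index to this rank.
    return iter(order[rank::num_replicas])


custom_order = [7, 3, 0, 9, 1, 4]  # as produced by some wrapped custom sampler
print(list(shard_indices(custom_order, num_replicas=2, rank=0)))  # [7, 0, 1]
print(list(shard_indices(custom_order, num_replicas=2, rank=1)))  # [3, 9, 4]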

src/lightning_fabric/wrappers.py

Lines changed: 8 additions & 0 deletions
@@ -151,6 +151,7 @@ def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None
         self.__dict__.update(dataloader.__dict__)
         self._dataloader = dataloader
         self._device = device
+        self._num_iter_calls = 0

     @property
     def device(self) -> Optional[torch.device]:
@@ -160,6 +161,13 @@ def __len__(self) -> int:
         return len(self._dataloader)

     def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]:
+        if hasattr(self._dataloader.sampler, "set_epoch"):
+            # Without setting the epoch, the distributed sampler would return the same indices every time, even when
+            # shuffling is enabled. In PyTorch, the user would normally have to call `.set_epoch()` on the sampler.
+            # In Lite, we take care of this boilerplate code.
+            self._dataloader.sampler.set_epoch(self._num_iter_calls)
+        self._num_iter_calls += 1
+
         iterator = iter(self._dataloader)
         if self._device is None:
             yield from iterator
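The comment added in `__iter__` can be checked with plain PyTorch. The snippet below is a minimal sketch on a toy dataset: a shuffling DistributedSampler repeats exactly the same order until its epoch is advanced, which is what the new `set_epoch()` call does on every fresh iterator.

import torch
from torch.utils.data import DistributedSampler

dataset = torch.arange(10)
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

epoch0_a = list(sampler)
epoch0_b = list(sampler)
assert epoch0_a == epoch0_b  # identical "epochs" when set_epoch() is never called

sampler.set_epoch(1)  # what _FabricDataLoader now does for each new iterator
epoch1 = list(sampler)  # reshuffled with seed + 1

Because `__iter__` is implemented as a generator, the `set_epoch()` call only fires once the first batch is requested, which is exactly what the new test in tests/tests_fabric/test_wrappers.py asserts.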

tests/tests_fabric/test_fabric.py

Lines changed: 40 additions & 1 deletion
@@ -405,7 +405,46 @@ def test_setup_dataloaders_distributed_sampler_shuffle():
     for dataloader in shuffle_dataloaders:
         seed_everything(1)
         dataloader = lite.setup_dataloaders(dataloader)
-        assert list(t[0].item() for t in iter(dataloader)) == [5, 0, 2, 1]
+        assert list(t[0].item() for t in iter(dataloader)) == [5, 2, 7, 1]
+
+
+@pytest.mark.parametrize("shuffle", [True, False])
+@pytest.mark.parametrize("batch_size", [1, 2, 3])
+def test_setup_dataloaders_distributed_sampler_parity(shuffle, batch_size):
+    """Test that the distributed sampler setup in Lite leads to the same sequence of data as in raw PyTorch."""
+    torch.manual_seed(1)
+    lite = Fabric(accelerator="cpu", strategy="ddp", devices=2)
+    # no lite.launch(): pretend we are on rank 0 now
+
+    dataset = torch.arange(10)
+    torch_dataloader = DataLoader(
+        dataset,
+        sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=shuffle),
+        batch_size=batch_size,
+    )
+    lite_dataloader = DataLoader(dataset, shuffle=shuffle, batch_size=batch_size)
+    lite_dataloader = lite.setup_dataloaders(lite_dataloader)
+
+    def fetch_epoch(loader):
+        iterator = iter(loader)
+        # we fetch 2 batches per epoch
+        return torch.cat((next(iterator), next(iterator)))
+
+    # 1st epoch
+    # PyTorch users need to set the epoch, while in Lite it gets handled automatically
+    torch_dataloader.sampler.set_epoch(0)
+    torch_data = fetch_epoch(torch_dataloader)
+    lite_data = fetch_epoch(lite_dataloader)
+    assert torch.equal(torch_data, lite_data)
+
+    # 2nd epoch
+    # PyTorch users need to set the epoch, while in Lite it gets handled automatically
+    torch_dataloader.sampler.set_epoch(1)
+    torch_data = fetch_epoch(torch_dataloader)
+    lite_data = fetch_epoch(lite_dataloader)
+    assert torch.equal(torch_data, lite_data)
+    assert torch_dataloader.sampler.epoch == 1
+    assert lite_dataloader._dataloader.sampler.epoch == 1


 @mock.patch.dict(os.environ, {}, clear=True)

tests/tests_fabric/test_wrappers.py

Lines changed: 22 additions & 1 deletion
@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import Mock
+from unittest.mock import call, Mock

 import pytest
 import torch
 from tests_fabric.helpers.runif import RunIf
+from torch.utils.data import DistributedSampler
 from torch.utils.data.dataloader import DataLoader

 from lightning_fabric.fabric import Fabric
@@ -230,6 +231,26 @@ def test_lite_dataloader_device_placement(src_device_str, dest_device_str):
     assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))


+def test_lite_dataloader_distributed_sampler_set_epoch():
+    """Test that the LiteDataLoader calls `set_epoch()` on the wrapped sampler if applicable."""
+    sampler = DistributedSampler(range(3), num_replicas=2, rank=0)
+    sampler.set_epoch = Mock()
+    dataloader = DataLoader(range(3), sampler=sampler)
+    lite_dataloader = _FabricDataLoader(dataloader)
+    iterator_epoch_0 = iter(lite_dataloader)
+    dataloader.sampler.set_epoch.assert_not_called()
+    next(iterator_epoch_0)
+    # .set_epoch() gets called before the first sample gets fetched from the wrapped dataloader
+    assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
+    next(iterator_epoch_0)
+    assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
+    iterator_epoch_1 = iter(lite_dataloader)
+    assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
+    next(iterator_epoch_1)
+    # with every new iterator call, the epoch increases
+    assert dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)]
+
+
 def test_lite_optimizer_wraps():
     """Test that the FabricOptimizer fully wraps the optimizer."""
     optimizer_cls = torch.optim.SGD
