@@ -17,7 +17,7 @@
 import pytest
 import torch
 from tests_fabric.helpers.runif import RunIf
-from torch.utils.data import DistributedSampler
+from torch.utils.data import BatchSampler, DistributedSampler
 from torch.utils.data.dataloader import DataLoader
 
 from lightning.fabric.fabric import Fabric
@@ -232,24 +232,36 @@ def test_fabric_dataloader_device_placement(src_device_str, dest_device_str):
     assert torch.equal(batch1["data"], torch.tensor([2, 3], device=dest_device))
 
 
-def test_fabric_dataloader_distributed_sampler_set_epoch():
+@pytest.mark.parametrize("use_batch_sampler", (False, True))
+def test_fabric_dataloader_distributed_sampler_set_epoch(use_batch_sampler):
     """Test that the FabricDataLoader calls `set_epoch()` on the wrapped sampler if applicable."""
-    sampler = DistributedSampler(range(3), num_replicas=2, rank=0)
+    dataset = range(3)
+    sampler = DistributedSampler(dataset, num_replicas=2, rank=0)
     sampler.set_epoch = Mock()
-    dataloader = DataLoader(range(3), sampler=sampler)
+
+    if not use_batch_sampler:
+        dataloader = DataLoader(dataset, sampler=sampler)
+    else:
+        batch_sampler = BatchSampler(sampler, batch_size=1, drop_last=False)
+        dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
+
     fabric_dataloader = _FabricDataLoader(dataloader)
     iterator_epoch_0 = iter(fabric_dataloader)
-    dataloader.sampler.set_epoch.assert_not_called()
+    sampler.set_epoch.assert_not_called()
+
     next(iterator_epoch_0)
     # .set_epoch() gets called before the first sample gets fetched from the wrapped dataloader
-    assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
+    assert sampler.set_epoch.mock_calls == [call(0)]
+
     next(iterator_epoch_0)
-    assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
+    assert sampler.set_epoch.mock_calls == [call(0)]
+
     iterator_epoch_1 = iter(fabric_dataloader)
-    assert dataloader.sampler.set_epoch.call_args_list == [call(0)]
+    assert sampler.set_epoch.mock_calls == [call(0)]
+
     next(iterator_epoch_1)
     # with every new iterator call, the epoch increases
-    assert dataloader.sampler.set_epoch.call_args_list == [call(0), call(1)]
 
 
+    assert sampler.set_epoch.mock_calls == [call(0), call(1)]
+
+
 def test_fabric_optimizer_wraps():
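Why the batch-sampler case is worth testing: PyTorch stores the sampler in a different place depending on how the `DataLoader` is constructed. With `DataLoader(..., sampler=...)`, the `DistributedSampler` sits at `dataloader.sampler`; with `DataLoader(..., batch_sampler=...)`, `dataloader.sampler` falls back to a default sampler and the `DistributedSampler` is only reachable at `dataloader.batch_sampler.sampler`. Asserting on the local `sampler` reference, rather than `dataloader.sampler.set_epoch`, keeps the assertions valid in both configurations. The sketch below shows one way a wrapper could locate the sampler whose `set_epoch()` it should call; `_find_distributed_sampler` is a hypothetical name for illustration, not Fabric's actual internal helper.

from typing import Optional

from torch.utils.data import BatchSampler, DataLoader, DistributedSampler


def _find_distributed_sampler(dataloader: DataLoader) -> Optional[DistributedSampler]:
    """Hypothetical lookup: return the DistributedSampler behind a DataLoader, if any."""
    # Case 1: the sampler was passed directly via `DataLoader(..., sampler=...)`.
    if isinstance(dataloader.sampler, DistributedSampler):
        return dataloader.sampler
    # Case 2: the sampler is nested inside a custom batch sampler via
    # `DataLoader(..., batch_sampler=...)`; here `dataloader.sampler` is a
    # default (sequential) sampler and the real one sits one level deeper.
    inner = getattr(dataloader.batch_sampler, "sampler", None)
    if isinstance(inner, DistributedSampler):
        return inner
    return None


if __name__ == "__main__":
    sampler = DistributedSampler(range(3), num_replicas=2, rank=0)
    batch_sampler = BatchSampler(sampler, batch_size=1, drop_last=False)
    dataloader = DataLoader(range(3), batch_sampler=batch_sampler)

    assert dataloader.sampler is not sampler  # the default sampler, not ours
    assert _find_distributed_sampler(dataloader) is sampler
    _find_distributed_sampler(dataloader).set_epoch(1)  # what a wrapper calls once per epoch

Calling `set_epoch()` before the first batch of each new iteration is what makes `DistributedSampler` reshuffle data across epochs in distributed training, which is exactly the behavior the parametrized test above pins down for both sampler layouts.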