
Commit 746c734

SequentialMode and dataloader_iter improvements (#16784)

1 parent: ad698f0

6 files changed (+106, -47 lines)

src/lightning/pytorch/CHANGELOG.md
Lines changed: 1 addition & 1 deletion

@@ -42,7 +42,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a new method `Strategy.on_exception` to the strategy base interface ([#16646](https://github.com/Lightning-AI/lightning/pull/16646))
 
 
-- Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743))
+- Added "sequential" mode support to `CombinedLoader` to consume multiple iterables in sequence ([#16743](https://github.com/Lightning-AI/lightning/pull/16743), [#16784](https://github.com/Lightning-AI/lightning/pull/16784))
 
 ### Changed
 
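For orientation: after this commit, "sequential" mode yields `(batch, batch_idx, dataloader_idx)` triplets rather than enumerated pairs (see the `CombinedLoader` docstring diff below). A minimal sketch, with plain lists standing in for dataloaders and illustrative values only:

    from lightning.pytorch.trainer.supporters import CombinedLoader

    # two toy "dataloaders"; any iterables work
    iterables = {"a": [10, 11], "b": [20]}
    combined = CombinedLoader(iterables, "sequential")
    for batch, batch_idx, dataloader_idx in combined:
        print(batch, batch_idx, dataloader_idx)
    # 10 0 0
    # 11 1 0
    # 20 0 1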

src/lightning/pytorch/loops/epoch/evaluation_epoch_loop.py
Lines changed: 6 additions & 5 deletions

@@ -114,11 +114,12 @@ def advance(
         Raises:
             StopIteration: If the current batch is None
         """
-        if not isinstance(data_fetcher, _DataLoaderIterDataFetcher):
-            batch_idx = self.batch_progress.current.ready
-            batch = next(data_fetcher)
-        else:
-            batch_idx, batch = next(data_fetcher)
+        batch_idx = (
+            data_fetcher.fetched
+            if isinstance(data_fetcher, _DataLoaderIterDataFetcher)
+            else self.batch_progress.current.ready
+        )
+        batch = next(data_fetcher)
         self.batch_progress.is_last_batch = data_fetcher.done
 
         dataloader_idx = kwargs.get("dataloader_idx", 0)
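The rewrite above collapses the two fetch paths into one: `batch_idx` is resolved first (from the fetcher's `fetched` count when the step consumes `dataloader_iter`, from the progress counter otherwise), and `next(data_fetcher)` then always returns just the batch. A toy illustration of that selection logic, using stand-in classes rather than Lightning's real fetcher and progress objects:

    class IterFetcher:  # stand-in for _DataLoaderIterDataFetcher
        fetched = 3  # batches already pulled from the underlying iterator

    class PlainFetcher:  # stand-in for any other _DataFetcher
        pass

    def resolve_batch_idx(fetcher, progress_ready):
        # mirrors the conditional expression introduced in `advance` above
        return fetcher.fetched if isinstance(fetcher, IterFetcher) else progress_ready

    assert resolve_batch_idx(IterFetcher(), progress_ready=0) == 3
    assert resolve_batch_idx(PlainFetcher(), progress_ready=7) == 7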

src/lightning/pytorch/loops/epoch/training_epoch_loop.py
Lines changed: 2 additions & 5 deletions

@@ -186,11 +186,8 @@ def advance(self, data_fetcher: _DataFetcher) -> None:
         # we are going to train first so the val loop does not need to restart
         self.val_loop.restarting = False
 
-        if not isinstance(data_fetcher, _DataLoaderIterDataFetcher):
-            batch_idx = self.batch_idx + 1
-            batch = next(data_fetcher)
-        else:
-            batch_idx, batch = next(data_fetcher)
+        batch_idx = data_fetcher.fetched if isinstance(data_fetcher, _DataLoaderIterDataFetcher) else self.batch_idx + 1
+        batch = next(data_fetcher)
         self.batch_progress.is_last_batch = data_fetcher.done
 
         trainer = self.trainer
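From the user's side, the `dataloader_iter` flavor of `training_step` (the signature shown in the fetcher docstring below) is unchanged: the step still pulls batches itself, while the loop now bookkeeps `batch_idx` via the fetcher. A hedged sketch; the class name and body are illustrative, not Lightning code:

    from typing import Iterator

    import lightning.pytorch as pl

    class PrefetchingModel(pl.LightningModule):
        def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
            # the step fetches batches itself; `batch_idx` is now derived from
            # the fetcher's `fetched` count instead of being unpacked from a tuple
            batch = next(dataloader_iter)
            ...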

src/lightning/pytorch/loops/fetchers.py
Lines changed: 24 additions & 9 deletions

@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple
+from typing import Any, Iterable, Iterator, List, Optional, Sized, Tuple, Union
 
 from torch.utils.data.dataloader import DataLoader
 
 from lightning.fabric.utilities.data import has_len
-from lightning.pytorch.trainer.supporters import _shutdown_workers_and_reset_iterator, CombinedLoader
+from lightning.pytorch.trainer.supporters import _Sequential, _shutdown_workers_and_reset_iterator, CombinedLoader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 
 
@@ -175,20 +175,35 @@ def training_step(self, dataloader_iter: Iterator, batch_idx: int) -> None:
 
     def __iter__(self) -> "_DataLoaderIterDataFetcher":
         super().__iter__()
-        iterator = self.dataloader_iter
-        assert iterator is not None
         self.iterator = iter(_DataFetcherWrapper(self))
         return self
 
-    def __next__(self) -> Tuple[int, Iterator]:
-        if not self.done:
-            return self.fetched, self.iterator
-        raise StopIteration
+    def __next__(self) -> Union["_DataFetcherWrapper", Tuple["_DataFetcherWrapper", int, int]]:
+        if self.done:
+            raise StopIteration
+        assert isinstance(self.iterator, _DataFetcherWrapper)
+        if self._is_sequential:
+            sequential_mode = self.dataloader._iterator
+            assert isinstance(sequential_mode, _Sequential)
+            batch_idx = sequential_mode._idx
+            dataloader_idx = sequential_mode._iterator_idx
+            return self.iterator, batch_idx, dataloader_idx
+        return self.iterator
+
+    @property
+    def _is_sequential(self) -> bool:
+        return isinstance(self.dataloader, CombinedLoader) and self.dataloader._mode == "sequential"
 
 
 class _DataFetcherWrapper(Iterator):
     def __init__(self, data_fetcher: _DataLoaderIterDataFetcher) -> None:
         self.data_fetcher = data_fetcher
 
     def __next__(self) -> Any:
-        return super(_DataLoaderIterDataFetcher, self.data_fetcher).__next__()
+        out = super(_DataLoaderIterDataFetcher, self.data_fetcher).__next__()
+        if self.data_fetcher._is_sequential:
+            # avoid breaking change with sequential mode and dataloader_iter. this is okay because
+            # dataloader_iter + sequential + multiple dataloaders is not supported so the `*_step(..., batch_idx)` value
+            # and the batch_index we are excluding here will match
+            return out[0]
+        return out
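Net effect of the two `__next__` changes: the loop-facing fetcher yields `(wrapper, batch_idx, dataloader_idx)` when a sequential `CombinedLoader` sits underneath, while the user-facing wrapper strips those indices so `next(dataloader_iter)` still returns the batch alone. A self-contained sketch of that contract with plain stand-in classes (none of these names are Lightning's):

    from typing import Iterator

    class ToyWrapper(Iterator):
        """Stand-in for _DataFetcherWrapper: hides the indices from the user."""

        def __init__(self, source, is_sequential):
            self.source = iter(source)
            self.is_sequential = is_sequential

        def __next__(self):
            out = next(self.source)
            # a sequential source yields (batch, batch_idx, dataloader_idx);
            # hand the user's `dataloader_iter` just the batch
            return out[0] if self.is_sequential else out

    dataloader_iter = ToyWrapper([("batch-a", 0, 0), ("batch-b", 0, 1)], is_sequential=True)
    assert next(dataloader_iter) == "batch-a"
    assert next(dataloader_iter) == "batch-b"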

src/lightning/pytorch/trainer/supporters.py
Lines changed: 46 additions & 22 deletions

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from collections.abc import Iterable
-from typing import Any, Callable, Iterator, List, Literal, Optional, Sized, Tuple, Type, TypeVar
+from typing import Any, Callable, Iterator, List, Literal, Optional, Sized, Tuple, Type, TypeVar, Union
 
 from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter
 from typing_extensions import Self, TypedDict
@@ -74,27 +74,47 @@ def __next__(self) -> List:
         return [next(it) for it in self.iterators]
 
 
-class _Sequential(_ModeIterator[Tuple[int, Any]]):
-    def __init__(self, iterables: List[Iterable]) -> None:
+class _Sequential(_ModeIterator[Tuple[Any, int, int]]):
+    def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None:
         super().__init__(iterables)
         self._iterator_idx = 0  # what would be dataloader_idx
         self._idx = 0  # what would be batch_idx
+        self.limits = limits
 
-    def __next__(self) -> Tuple[int, Any]:
+    @property
+    def limits(self) -> Optional[List[Union[int, float]]]:
+        """Optional limits per iterator."""
+        return self._limits
+
+    @limits.setter
+    def limits(self, limits: Optional[List[Union[int, float]]]) -> None:
+        if limits is not None and len(limits) != len(self.iterables):
+            raise ValueError(
+                f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(self.iterables)})"
+            )
+        self._limits = limits
+
+    def __next__(self) -> Tuple[Any, int, int]:
         n = len(self.iterators)
-        if n == 0:
+        if n == 0 or self._iterator_idx >= n:
             raise StopIteration
+
+        # if limits are set, go to the correct iterator
+        if self.limits is not None:
+            while self.limits[self._iterator_idx] <= self._idx:
+                self._use_next_iterator()
+                if self._iterator_idx >= n:
+                    raise StopIteration
+
         try:
             out = next(self.iterators[self._iterator_idx])
             index = self._idx
             self._idx += 1
-            # the return is enumerated by default
-            return index, out
+            # batch, batch_idx, dataloader_idx
            return out, index, self._iterator_idx
         except StopIteration:
-            self._iterator_idx += 1
-            self._idx = 0
-            if self._iterator_idx >= n:
-                raise
+            # try the next iterator
+            self._use_next_iterator()
             return self.__next__()
 
     def __iter__(self) -> Self:  # type: ignore[valid-type]
@@ -108,6 +128,10 @@ def reset(self) -> None:
         self._iterator_idx = 0
         self._idx = 0
 
+    def _use_next_iterator(self) -> None:
+        self._iterator_idx += 1
+        self._idx = 0
+
 
 class _CombinationMode(TypedDict):
     fn: Callable[[List[int]], int]
@@ -170,28 +194,28 @@ class CombinedLoader(Iterable):
     >>> combined_loader = CombinedLoader(iterables, 'max_size_cycle')
     >>> len(combined_loader)
     3
-    >>> for item in combined_loader:
-    ...     print(item)
+    >>> for batch in combined_loader:
+    ...     print(batch)
     {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
     {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
     {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}
     >>> combined_loader = CombinedLoader(iterables, 'min_size')
     >>> len(combined_loader)
     2
-    >>> for item in combined_loader:
-    ...     print(item)
+    >>> for batch in combined_loader:
+    ...     print(batch)
     {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
     {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
     >>> combined_loader = CombinedLoader(iterables, 'sequential')
     >>> len(combined_loader)
     5
-    >>> for item in combined_loader:
-    ...     print(*item)
-    0 tensor([0, 1, 2, 3])
-    1 tensor([4, 5])
-    0 tensor([0, 1, 2, 3, 4])
-    1 tensor([5, 6, 7, 8, 9])
-    2 tensor([10, 11, 12, 13, 14])
+    >>> for batch, batch_idx, dataloader_idx in combined_loader:
+    ...     print(f"{batch} {batch_idx=} {dataloader_idx=}")
+    tensor([0, 1, 2, 3]) batch_idx=0 dataloader_idx=0
+    tensor([4, 5]) batch_idx=1 dataloader_idx=0
+    tensor([0, 1, 2, 3, 4]) batch_idx=0 dataloader_idx=1
+    tensor([5, 6, 7, 8, 9]) batch_idx=1 dataloader_idx=1
+    tensor([10, 11, 12, 13, 14]) batch_idx=2 dataloader_idx=1
     """
 
     def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None:
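The new `limits` argument caps how many items `_Sequential` takes from each iterable before advancing to the next: `0` skips an iterable entirely and `float('inf')` leaves it uncapped. `_Sequential` is a private helper, so this usage sketch (mirroring the parametrized tests below) is for illustration only:

    from lightning.pytorch.trainer.supporters import _Sequential

    # take at most one item from each of the two iterables
    seq = _Sequential([["a", "b", "c"], ["d", "e"]], limits=[1, 1])
    assert list(seq) == [("a", 0, 0), ("d", 0, 1)]

    # a limit of 0 skips the first iterable entirely
    seq = _Sequential([["a", "b", "c"], ["d", "e"]], limits=[0, float("inf")])
    assert list(seq) == [("d", 0, 1), ("e", 1, 1)]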

tests/tests_pytorch/trainer/test_supporters.py
Lines changed: 27 additions & 5 deletions

@@ -122,13 +122,14 @@ def test_combined_loader_modes():
     combined_loader = CombinedLoader(iterables, "sequential")
     assert combined_loader._iterator is None
     assert len(combined_loader) == sum_len
-    for total_idx, (idx, item) in enumerate(combined_loader):
+    for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader):
         assert isinstance(combined_loader._iterator, _Sequential)
-        assert isinstance(idx, int)
+        assert isinstance(batch_idx, int)
         assert isinstance(item, Tensor)
     assert idx == lengths[-1] - 1
     assert total_idx == sum_len - 1
     assert total_idx == len(combined_loader) - 1
+    assert dataloader_idx == len(iterables) - 1
 
     iterables = list(iterables.values())
 
@@ -156,13 +157,14 @@ def test_combined_loader_modes():
     combined_loader = CombinedLoader(iterables, "sequential")
     assert combined_loader._iterator is None
     assert len(combined_loader) == sum_len
-    for total_idx, (idx, item) in enumerate(combined_loader):
+    for total_idx, (item, batch_idx, dataloader_idx) in enumerate(combined_loader):
         assert isinstance(combined_loader._iterator, _Sequential)
-        assert isinstance(idx, int)
+        assert isinstance(batch_idx, int)
         assert isinstance(item, Tensor)
     assert idx == lengths[-1] - 1
     assert total_idx == sum_len - 1
     assert total_idx == len(combined_loader) - 1
+    assert dataloader_idx == len(iterables) - 1
 
 
 def test_combined_loader_raises():
@@ -205,7 +207,6 @@ def test_combined_loader_sequence_iterable_dataset(mode, use_multiple_dataloaders):
     has_break = False
     for idx, item in enumerate(combined_loader):
         assert isinstance(item, Sequence)
-        assert len(item) == 2 if use_multiple_dataloaders else 1
         if not use_multiple_dataloaders and idx == 4:
             has_break = True
             break
@@ -221,6 +222,27 @@
     assert idx == expected - 1
 
 
+@pytest.mark.parametrize(
+    ("limits", "expected"),
+    [
+        (None, [("a", 0, 0), ("b", 1, 0), ("c", 2, 0), ("d", 0, 1), ("e", 1, 1)]),
+        ([1, 0], [("a", 0, 0)]),
+        ([0, float("inf")], [("d", 0, 1), ("e", 1, 1)]),
+        ([1, 1], [("a", 0, 0), ("d", 0, 1)]),
+    ],
+)
+def test_sequential_mode_limits(limits, expected):
+    iterable1 = ["a", "b", "c"]
+    iterable2 = ["d", "e"]
+    iterator = _Sequential([iterable1, iterable2], limits)
+    assert list(iterator) == expected
+
+
+def test_sequential_mode_limits_raises():
+    with pytest.raises(ValueError, match=r"number of limits \(0\) and number of iterables \(2\)"):
+        _Sequential([0, 1], [])
+
+
 @pytest.mark.parametrize("lengths", [[4, 6], [5, 5], [6, 4]])
 def test_combined_loader_sequence_with_map_and_iterable(lengths):
     class MyIterableDataset(IterableDataset):
