
Commit 2add5d3

Refactor supporters (#16662)
1 parent 35b8543 commit 2add5d3

File tree

2 files changed: +107 −166 lines changed


src/lightning/pytorch/trainer/supporters.py

Lines changed: 94 additions & 148 deletions
@@ -11,17 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from collections.abc import Sized
+import functools
 from dataclasses import dataclass, field
 from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Union

 from lightning_utilities.core.apply_func import apply_to_collection
 from torch.utils.data import Dataset
 from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
 from torch.utils.data.dataset import IterableDataset
+from typing_extensions import TypedDict

+from lightning.fabric.utilities.data import sized_len
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.types import _NUMBER


 @dataclass
@@ -56,16 +58,12 @@ def done(self) -> bool:
 class CycleIterator:
     """Iterator for restarting a dataloader if it runs out of samples."""

-    def __init__(self, loader: Any, length: Optional[Union[int, float]] = None, state: SharedCycleIteratorState = None):
+    def __init__(self, loader: Any, length: _NUMBER = float("inf"), state: SharedCycleIteratorState = None):
         """
         Args:
             loader: the loader to restart for cyclic (and optionally infinite) sampling
             length: the number of batches to sample (with restarted loaders if necessary) before raising StopIteration
-                if None: infinite
         """
-        if length is None:
-            length = float("inf")
-
         if not state:
             state = SharedCycleIteratorState()
             state.dataloaders.append(loader)
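
The new signature drops the `Optional`/`None` sentinel: an unspecified `length` now defaults to `float("inf")`, meaning "cycle forever". A minimal standalone sketch of that restart-until-length behaviour (toy names, not the Lightning implementation):

from typing import Any, Iterable, Iterator


class _ToyCycleIterator:
    def __init__(self, loader: Iterable, length: float = float("inf")) -> None:
        # float("inf") as the default means "cycle forever" without a None sentinel
        self.loader = loader
        self.length = length

    def __iter__(self) -> Iterator[Any]:
        produced = 0
        inner = iter(self.loader)
        while produced < self.length:
            try:
                yield next(inner)
            except StopIteration:
                inner = iter(self.loader)  # restart the exhausted loader
                yield next(inner)
            produced += 1


print(list(_ToyCycleIterator(range(3), length=7)))  # [0, 1, 2, 0, 1, 2, 0]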
@@ -125,74 +123,45 @@ def __next__(self) -> Any:
         finally:
             self.counter += 1

-    def __len__(self) -> Union[int, float]:
+    def __len__(self) -> _NUMBER:
+        # TODO: returning float here is a hack
         return self.length


-class CombinedDataset:
-    """Combine multiple datasets and compute their statistics."""
+class _CombinationMode(TypedDict):
+    name: str
+    fn: Callable[[_NUMBER, _NUMBER], _NUMBER]
+    default: _NUMBER

-    COMPUTE_FUNCS = {"min_size": min, "max_size_cycle": max}

-    def __init__(self, datasets: Union[Sequence, Mapping], mode: str = "min_size"):
+_supported_modes = {
+    "min_size": _CombinationMode(name="min_size", fn=min, default=float("inf")),
+    "max_size_cycle": _CombinationMode(name="max_size_cycle", fn=max, default=float("-inf")),
+}
+
+
+class CombinedDataset:
+    """Combine multiple datasets."""
+
+    def __init__(self, datasets: Any, mode: str = "min_size"):
         """
         Args:
-            datasets: a sequence/mapping datasets. Can be a collections of torch.utils.Dataset,
-                Iterable or even None.
+            datasets: Collections of Iterables.
             mode: whether to use the minimum number of batches in all samples or the maximum
                 number of batches in all samples.
         """
-        self.datasets = datasets
-        if mode not in self.COMPUTE_FUNCS.keys():
-            raise MisconfigurationException(
-                f'You have selected unsupported mode "{mode}",'
-                f" please select one the: {list(self.COMPUTE_FUNCS.keys())}."
-            )
-        self.mode = mode
+        if mode not in _supported_modes:
+            raise ValueError(f"Unsupported mode {mode!r}, please select one of: {list(_supported_modes)}.")
+        self._mode = mode
+        self._datasets = datasets

     @property
-    def max_len(self) -> Union[int, float]:
-        return self._calc_num_data(self.datasets, "max_size_cycle")
-
-    @property
-    def min_len(self) -> Union[int, float]:
-        return self._calc_num_data(self.datasets, "min_size")
-
-    def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]:
-        """Compute the length of `CombinedDataset` according to the `mode`.
-
-        Args:
-            datasets: a sequence/mapping datasets. Can be a collections of torch.utils.data.Dataset,
-                Iterable or even None.
-            mode: Determine `CombinedDataset`'s length is the maximum or minimum of
-                the datasets.
-
-        Returns:
-            length: the length of `CombinedDataset`
-        """
-        if mode not in self.COMPUTE_FUNCS.keys():
-            raise MisconfigurationException(f"Invalid Mode: {mode}")
+    def datasets(self) -> Any:
+        return self._datasets

-        # extract the lengths
-        all_lengths = self._get_len_recursive(datasets)
-
-        compute_func = self.COMPUTE_FUNCS[mode]
-
-        if isinstance(all_lengths, (int, float)):
-            length = all_lengths
-        else:
-            length = _nested_calc_num_data(all_lengths, compute_func)
-
-        return length
-
-    def _get_len_recursive(self, data: Any) -> Union[int, float, List, Dict]:
-        if isinstance(data, Dataset):
-            assert isinstance(data, Sized)
-            return len(data)
-
-        if isinstance(data, (float, int)):
+    def _get_len_recursive(self, data: Any) -> Union[int, List, Dict]:
+        if isinstance(data, int):
             return data
-
         if isinstance(data, Mapping):
             if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data.values()):
                 return {k: self._get_len_recursive(v) for k, v in data.items()}
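
The refactor replaces the `COMPUTE_FUNCS` class attribute with a module-level lookup table of `_CombinationMode` TypedDicts, each bundling the reduction function with the identity value the reduction starts from. A small self-contained sketch of that pattern (the `_NUMBER` alias here is an assumption, taken to be `Union[int, float]`):

from typing import Callable, Union

from typing_extensions import TypedDict

_NUMBER = Union[int, float]  # assumption: mirrors lightning.pytorch.utilities.types._NUMBER


class _Mode(TypedDict):
    name: str
    fn: Callable[[_NUMBER, _NUMBER], _NUMBER]
    default: _NUMBER


modes = {
    "min_size": _Mode(name="min_size", fn=min, default=float("inf")),
    "max_size_cycle": _Mode(name="max_size_cycle", fn=max, default=float("-inf")),
}

lengths = [6, 15, 9]
for name, mode in modes.items():
    total = mode["default"]  # identity element: inf for min, -inf for max
    for value in lengths:
        total = mode["fn"](total, value)
    print(name, total)  # min_size 6, max_size_cycle 15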
@@ -201,53 +170,56 @@ def _get_len_recursive(self, data: Any) -> Union[int, float, List, Dict]:
             if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data):
                 return [self._get_len_recursive(v) for v in data]

-        return self._get_len(data)
-
-    @staticmethod
-    def _get_len(dataset: Any) -> Union[int, float]:
-        try:
-            return len(dataset)
-        except (TypeError, NotImplementedError):
-            return float("inf")
+        length = sized_len(data)
+        if length is None:
+            raise ValueError(f"Couldn't compute the length of {data}")
+        return length

-    def __len__(self) -> Union[int, float]:
-        """Return the minimum length of the datasets."""
-        return self._calc_num_data(self.datasets, self.mode)
+    @functools.lru_cache(maxsize=1)
+    def __len__(self) -> int:
+        """Compute the length of `CombinedDataset` according to the `mode`."""
+        all_lengths = self._get_len_recursive(self.datasets)
+        mode = _supported_modes[self._mode]
+        total_length = _reduce_data(all_lengths, mode["fn"], mode["default"])
+        if isinstance(total_length, float):
+            raise TypeError(f"The total size of the datasets must be an int, found {total_length}")
+        return total_length


 class CombinedLoader:
-    """Combines different dataloaders and allows sampling in parallel. Supported modes are ``"min_size"``, which
-    raises StopIteration after the shortest loader (the one with the lowest number of batches) is done, and
-    ``"max_size_cycle"`` which raises StopIteration after the longest loader (the one with most batches) is done,
-    while cycling through the shorter loaders.
+    """Combines different dataloaders and allows sampling in parallel.
+
+    Args:
+        loaders: the loaders to sample from. Can be all kind of collection
+        mode:
+            * ``"min_size"``, which raises StopIteration after the shortest loader (the one with the lowest number of
+              batches) is done.
+            * ``"max_size_cycle"`` which raises StopIteration after the longest loader (the one with most batches) is
+              done, while cycling through the shorter loaders.

     Examples:
         >>> loaders = {'a': DataLoader(range(6), batch_size=4),
         ...            'b': DataLoader(range(15), batch_size=5)}
         >>> combined_loader = CombinedLoader(loaders, 'max_size_cycle')
+        >>> len(combined_loader)
+        3
         >>> for item in combined_loader:
         ...     print(item)
         {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
         {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
         {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}
         >>> combined_loader = CombinedLoader(loaders, 'min_size')
+        >>> len(combined_loader)
+        2
         >>> for item in combined_loader:
         ...     print(item)
         {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
         {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
     """

-    SUPPORTED_MODES = ("min_size", "max_size_cycle")
-
     def __init__(self, loaders: Any, mode: str = "min_size"):
-        """
-        Args:
-            loaders: the loaders to sample from. Can be all kind of collection
-            mode: the mode. Supported are 'min_size' which stops if the shortest loader is exhausted and
-                'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones.
-        """
-        if mode not in self.SUPPORTED_MODES:
-            raise MisconfigurationException(f"Invalid Mode: {mode}")
+        if mode not in _supported_modes:
+            raise ValueError(f"Unsupported mode {mode!r}, please select one of: {list(_supported_modes)}.")

         self.loaders = loaders

@@ -257,57 +229,49 @@ def __init__(self, loaders: Any, mode: str = "min_size"):
         # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
         self.dataset = CombinedDataset(datasets, mode)

-        self.mode = mode
-
-        if self.mode == "max_size_cycle":
-            self._wrap_loaders_max_size_cycle()
+        self._mode = mode
+        self._wrap_loaders_max_size_cycle()

         self._iterator: Optional[Iterator] = None  # assigned in __iter__

     @property
-    def sampler(self) -> Union[Iterable, Sequence, Mapping]:
+    def sampler(self) -> Any:
         """Return a collections of samplers extracted from loaders."""
         return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "sampler", None)

     @property
-    def batch_sampler(self) -> Union[Iterable, Sequence, Mapping]:
+    def batch_sampler(self) -> Any:
         """Return a collections of batch samplers extracted from loaders."""
         return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "batch_sampler", None)

-    def _wrap_loaders_max_size_cycle(self) -> Any:
+    def _wrap_loaders_max_size_cycle(self) -> None:
         """Wraps all loaders to make sure they are cycled until the longest loader is exhausted.

         Returns:
             the wrapped loaders
         """
-        from lightning.pytorch.utilities.data import get_len
-
-        all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
-
-        length = _nested_calc_num_data(all_lengths, max)
-
-        # multiple loaders
-        if isinstance(self.loaders, (Sequence, Mapping)):
-            state = SharedCycleIteratorState()
-
-            self.loaders = apply_to_collection(
-                self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
-            )
-            state.reset()
+        if self._mode != "max_size_cycle" or not isinstance(self.loaders, (Sequence, Mapping)):
+            return
+        length = self._calc_num_batches()
+        state = SharedCycleIteratorState()
+        self.loaders = apply_to_collection(
+            self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
+        )
+        state.reset()

     def _apply_cycle_iterator_length(self) -> None:
         """When the model is `max_size_cycle`, compute the length across all ``CycleIterator`` and re-assign it to
         all dataloaders."""
-        from lightning.pytorch.utilities.data import get_len
-
-        if self.mode != "max_size_cycle":
+        if self._mode != "max_size_cycle":
             return

+        from lightning.pytorch.utilities.data import get_len
+
         def set_len(cycle_iterator: CycleIterator, length: int) -> None:
             cycle_iterator.length = length

         all_lengths = apply_to_collection(self.loaders, CycleIterator, lambda c: get_len(c.loader))
-        max_length = _nested_calc_num_data(all_lengths, max)
+        max_length = _reduce_data(all_lengths, max, float("-inf"))
         apply_to_collection(self.loaders, CycleIterator, set_len, length=max_length)

     def __iter__(self) -> Any:
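
The wrapping step leans on `apply_to_collection` from `lightning_utilities`: it walks a (possibly nested) collection, applies the given function to every leaf matching `dtype`, skips containers listed in `wrong_dtype`, and forwards extra keyword arguments to the function. A sketch of the same call shape with a toy wrapper standing in for `CycleIterator`:

from typing import Iterable, Mapping, Sequence

from lightning_utilities.core.apply_func import apply_to_collection
from torch.utils.data import DataLoader


def _wrap(loader: Iterable, length: int) -> tuple:
    # toy stand-in for CycleIterator(loader, length=length, state=state)
    return ("cycled", loader, length)


loaders = {"a": DataLoader(range(6), batch_size=4), "b": DataLoader(range(15), batch_size=5)}
wrapped = apply_to_collection(loaders, Iterable, _wrap, length=3, wrong_dtype=(Sequence, Mapping))
print({key: (value[0], value[2]) for key, value in wrapped.items()})  # {'a': ('cycled', 3), 'b': ('cycled', 3)}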
@@ -323,27 +287,19 @@ def __getstate__patch__(*_: Any) -> Dict:
         self._iterator = iterator
         return iterator

-    @staticmethod
-    def _calc_num_batches(loaders: Any, mode: str = "min_size") -> Union[int, float]:
-        """Compute the length (aka the number of batches) of `CombinedLoader`.
-
-        Args:
-            loaders: a collections of loaders.
-            mode: Mode used by the CombinedDataloader
-
-        Returns:
-            length: the minimum length of loaders
-        """
+    def _calc_num_batches(self) -> _NUMBER:
         from lightning.pytorch.utilities.data import get_len

-        all_lengths = apply_to_collection(loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
-
-        if isinstance(all_lengths, (int, float)):
-            return all_lengths
-        return _nested_calc_num_data(all_lengths, max if mode == "max_size_cycle" else min)
-
-    def __len__(self) -> Union[int, float]:
-        return self._calc_num_batches(self.loaders, mode=self.mode)
+        all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
+        mode = _supported_modes[self._mode]
+        return _reduce_data(all_lengths, mode["fn"], mode["default"])
+
+    def __len__(self) -> int:
+        """Compute the number of batches."""
+        length = self._calc_num_batches()
+        if isinstance(length, float):
+            raise TypeError(f"Number of batches must be an int, found {length}")
+        return length

     @staticmethod
     def _shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
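
For context on the numbers in the `CombinedLoader` doctest: with the loaders used there, the per-loader batch counts are 2 and 3, so `min_size` reports 2 and `max_size_cycle` reports 3. A quick check (assuming torch is available):

from torch.utils.data import DataLoader

loaders = {"a": DataLoader(range(6), batch_size=4), "b": DataLoader(range(15), batch_size=5)}
lengths = {name: len(loader) for name, loader in loaders.items()}
print(lengths)                # {'a': 2, 'b': 3}
print(min(lengths.values()))  # 2 -> "min_size"
print(max(lengths.values()))  # 3 -> "max_size_cycle"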
@@ -417,25 +373,15 @@ def create_loader_iters(
         return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))


-def _nested_calc_num_data(
-    data: Union[Mapping, Sequence], compute_func: Callable[[List[Union[int, float]]], Union[int, float]]
-) -> Union[int, float]:
-
-    if isinstance(data, (float, int)):
-        return data
-
-    if isinstance(data, Mapping):
-        data = list(data.values())
-
-    if not isinstance(data, Sequence):
+def _reduce_data(data: Any, pairwise_reduction: Callable[[_NUMBER, _NUMBER], _NUMBER], default: _NUMBER) -> _NUMBER:
+    if data is None:
         raise TypeError(f"Expected data to be int, Sequence or Mapping, but got {type(data).__name__}")

-    new_data = []
+    total = default

-    for x in data:
-        if isinstance(x, (Mapping, Sequence)):
-            new_data.append(_nested_calc_num_data(x, compute_func))
-        else:
-            new_data.append(x)
+    def reduce(v: _NUMBER) -> None:
+        nonlocal total
+        total = pairwise_reduction(total, v)

-    return compute_func(new_data)
+    apply_to_collection(data, (int, float), reduce)
+    return total
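
The new `_reduce_data` helper folds every numeric leaf of a nested collection into one value by closing over an accumulator and letting `apply_to_collection` do the traversal. A self-contained sketch of the same pattern with a nested input (the `_NUMBER` alias here is an assumption, taken to be `Union[int, float]`):

from typing import Any, Callable, Union

from lightning_utilities.core.apply_func import apply_to_collection

_NUMBER = Union[int, float]  # assumption: mirrors lightning.pytorch.utilities.types._NUMBER


def reduce_data(data: Any, pairwise_reduction: Callable[[_NUMBER, _NUMBER], _NUMBER], default: _NUMBER) -> _NUMBER:
    total = default

    def reduce(value: _NUMBER) -> None:
        nonlocal total
        total = pairwise_reduction(total, value)

    # apply_to_collection visits every int/float leaf; only the side effect on `total` matters here
    apply_to_collection(data, (int, float), reduce)
    return total


nested = {"train": [2, 3], "val": {"x": 5, "y": 7}}
print(reduce_data(nested, max, float("-inf")))  # 7
print(reduce_data(nested, min, float("inf")))   # 2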
