# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
- from collections.abc import Sized
+ import functools
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, Optional, Sequence, Union

from lightning_utilities.core.apply_func import apply_to_collection
from torch.utils.data import Dataset
from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
from torch.utils.data.dataset import IterableDataset
+ from typing_extensions import TypedDict

+ from lightning.fabric.utilities.data import sized_len
from lightning.pytorch.utilities.exceptions import MisconfigurationException
+ from lightning.pytorch.utilities.types import _NUMBER


@dataclass
@@ -56,16 +58,12 @@ def done(self) -> bool:
class CycleIterator:
    """Iterator for restarting a dataloader if it runs out of samples."""

-     def __init__(self, loader: Any, length: Optional[Union[int, float]] = None, state: SharedCycleIteratorState = None):
+     def __init__(self, loader: Any, length: _NUMBER = float("inf"), state: SharedCycleIteratorState = None):
        """
        Args:
            loader: the loader to restart for cyclic (and optionally infinite) sampling
            length: the number of batches to sample (with restarted loaders if necessary) before raising StopIteration
-                 if None: infinite
        """
-         if length is None:
-             length = float("inf")
-
        if not state:
            state = SharedCycleIteratorState()
            state.dataloaders.append(loader)
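
A standalone, hypothetical sketch of the cycling behaviour this hunk touches: restart the underlying iterable whenever it is exhausted and stop after `length` items. It is not the Lightning implementation and ignores the shared-state handling, but it shows the idea `CycleIterator` implements.

from typing import Any, Iterable, Iterator


def cycle_n(iterable: Iterable[Any], length: int) -> Iterator[Any]:
    """Yield `length` items, restarting `iterable` whenever it runs out of samples."""
    counter = 0
    it = iter(iterable)
    while counter < length:
        try:
            yield next(it)
        except StopIteration:
            it = iter(iterable)  # restart, as CycleIterator does with its loader
            yield next(it)
        counter += 1


print(list(cycle_n(range(3), 7)))  # [0, 1, 2, 0, 1, 2, 0]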
@@ -125,74 +123,45 @@ def __next__(self) -> Any:
        finally:
            self.counter += 1

-     def __len__(self) -> Union[int, float]:
+     def __len__(self) -> _NUMBER:
+         # TODO: returning float here is a hack
        return self.length


- class CombinedDataset:
-     """Combine multiple datasets and compute their statistics."""
+ class _CombinationMode(TypedDict):
+     name: str
+     fn: Callable[[_NUMBER, _NUMBER], _NUMBER]
+     default: _NUMBER

-     COMPUTE_FUNCS = {"min_size": min, "max_size_cycle": max}

-     def __init__(self, datasets: Union[Sequence, Mapping], mode: str = "min_size"):
+ _supported_modes = {
+     "min_size": _CombinationMode(name="min_size", fn=min, default=float("inf")),
+     "max_size_cycle": _CombinationMode(name="max_size_cycle", fn=max, default=float("-inf")),
+ }
+
+
+ class CombinedDataset:
+     """Combine multiple datasets."""
+
+     def __init__(self, datasets: Any, mode: str = "min_size"):
        """
        Args:
-             datasets: a sequence/mapping datasets. Can be a collections of torch.utils.Dataset,
-                 Iterable or even None.
+             datasets: Collections of Iterables.
            mode: whether to use the minimum number of batches in all samples or the maximum
                number of batches in all samples.
        """
-         self.datasets = datasets
-         if mode not in self.COMPUTE_FUNCS.keys():
-             raise MisconfigurationException(
-                 f'You have selected unsupported mode "{mode}",'
-                 f" please select one the: {list(self.COMPUTE_FUNCS.keys())}."
-             )
-         self.mode = mode
+         if mode not in _supported_modes:
+             raise ValueError(f"Unsupported mode {mode!r}, please select one of: {list(_supported_modes)}.")
+         self._mode = mode
+         self._datasets = datasets

    @property
-     def max_len(self) -> Union[int, float]:
-         return self._calc_num_data(self.datasets, "max_size_cycle")
-
-     @property
-     def min_len(self) -> Union[int, float]:
-         return self._calc_num_data(self.datasets, "min_size")
-
-     def _calc_num_data(self, datasets: Union[Sequence, Mapping], mode: str) -> Union[int, float]:
-         """Compute the length of `CombinedDataset` according to the `mode`.
-
-         Args:
-             datasets: a sequence/mapping datasets. Can be a collections of torch.utils.data.Dataset,
-                 Iterable or even None.
-             mode: Determine `CombinedDataset`'s length is the maximum or minimum of
-                 the datasets.
-
-         Returns:
-             length: the length of `CombinedDataset`
-         """
-         if mode not in self.COMPUTE_FUNCS.keys():
-             raise MisconfigurationException(f"Invalid Mode: {mode}")
+     def datasets(self) -> Any:
+         return self._datasets

-         # extract the lengths
-         all_lengths = self._get_len_recursive(datasets)
-
-         compute_func = self.COMPUTE_FUNCS[mode]
-
-         if isinstance(all_lengths, (int, float)):
-             length = all_lengths
-         else:
-             length = _nested_calc_num_data(all_lengths, compute_func)
-
-         return length
-
-     def _get_len_recursive(self, data: Any) -> Union[int, float, List, Dict]:
-         if isinstance(data, Dataset):
-             assert isinstance(data, Sized)
-             return len(data)
-
-         if isinstance(data, (float, int)):
+     def _get_len_recursive(self, data: Any) -> Union[int, List, Dict]:
+         if isinstance(data, int):
            return data
-
        if isinstance(data, Mapping):
            if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data.values()):
                return {k: self._get_len_recursive(v) for k, v in data.items()}
@@ -201,53 +170,56 @@ def _get_len_recursive(self, data: Any) -> Union[int, float, List, Dict]:
            if any(isinstance(v, (Mapping, Sequence, Dataset, Iterable)) for v in data):
                return [self._get_len_recursive(v) for v in data]

-         return self._get_len(data)
-
-     @staticmethod
-     def _get_len(dataset: Any) -> Union[int, float]:
-         try:
-             return len(dataset)
-         except (TypeError, NotImplementedError):
-             return float("inf")
+         length = sized_len(data)
+         if length is None:
+             raise ValueError(f"Couldn't compute the length of {data}")
+         return length

-     def __len__(self) -> Union[int, float]:
-         """Return the minimum length of the datasets."""
-         return self._calc_num_data(self.datasets, self.mode)
+     @functools.lru_cache(maxsize=1)
+     def __len__(self) -> int:
+         """Compute the length of `CombinedDataset` according to the `mode`."""
+         all_lengths = self._get_len_recursive(self.datasets)
+         mode = _supported_modes[self._mode]
+         total_length = _reduce_data(all_lengths, mode["fn"], mode["default"])
+         if isinstance(total_length, float):
+             raise TypeError(f"The total size of the datasets must be an int, found {total_length}")
+         return total_length


class CombinedLoader:
-     """Combines different dataloaders and allows sampling in parallel. Supported modes are ``"min_size"``, which
-     raises StopIteration after the shortest loader (the one with the lowest number of batches) is done, and
-     ``"max_size_cycle"`` which raises StopIteration after the longest loader (the one with most batches) is done,
-     while cycling through the shorter loaders.
+     """Combines different dataloaders and allows sampling in parallel.
+
+     Args:
+         loaders: the loaders to sample from. Can be all kind of collection
+         mode:
+             * ``"min_size"``, which raises StopIteration after the shortest loader (the one with the lowest number of
+               batches) is done.
+             * ``"max_size_cycle"`` which raises StopIteration after the longest loader (the one with most batches) is
+               done, while cycling through the shorter loaders.

    Examples:
        >>> loaders = {'a': DataLoader(range(6), batch_size=4),
        ...            'b': DataLoader(range(15), batch_size=5)}
        >>> combined_loader = CombinedLoader(loaders, 'max_size_cycle')
+         >>> len(combined_loader)
+         3
        >>> for item in combined_loader:
        ...     print(item)
        {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
        {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
        {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}
        >>> combined_loader = CombinedLoader(loaders, 'min_size')
+         >>> len(combined_loader)
+         2
        >>> for item in combined_loader:
        ...     print(item)
        {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
        {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
    """

-     SUPPORTED_MODES = ("min_size", "max_size_cycle")
-
    def __init__(self, loaders: Any, mode: str = "min_size"):
-         """
-         Args:
-             loaders: the loaders to sample from. Can be all kind of collection
-             mode: the mode. Supported are 'min_size' which stops if the shortest loader is exhausted and
-                 'max_size_cycle' which stops if the longest loader is exhausted and cycles through the smaller ones.
-         """
-         if mode not in self.SUPPORTED_MODES:
-             raise MisconfigurationException(f"Invalid Mode: {mode}")
+         if mode not in _supported_modes:
+             raise ValueError(f"Unsupported mode {mode!r}, please select one of: {list(_supported_modes)}.")

        self.loaders = loaders

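
For intuition on how `CombinedDataset.__len__` combines per-dataset lengths under each mode, here is a hedged, standalone sketch; plain lists stand in for datasets, while the real code extracts lengths recursively via `_get_len_recursive` and folds them with `_reduce_data`.

# Toy stand-ins for datasets; any sized collection works for this illustration.
datasets = {"a": list(range(10)), "b": list(range(5))}
lengths = {k: len(v) for k, v in datasets.items()}  # {"a": 10, "b": 5}

assert min(lengths.values()) == 5    # "min_size": combined length follows the smallest dataset
assert max(lengths.values()) == 10   # "max_size_cycle": combined length follows the largest dataset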
@@ -257,57 +229,49 @@ def __init__(self, loaders: Any, mode: str = "min_size"):
        # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
        self.dataset = CombinedDataset(datasets, mode)

-         self.mode = mode
-
-         if self.mode == "max_size_cycle":
-             self._wrap_loaders_max_size_cycle()
+         self._mode = mode
+         self._wrap_loaders_max_size_cycle()

        self._iterator: Optional[Iterator] = None  # assigned in __iter__

    @property
-     def sampler(self) -> Union[Iterable, Sequence, Mapping]:
+     def sampler(self) -> Any:
        """Return a collections of samplers extracted from loaders."""
        return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "sampler", None)

    @property
-     def batch_sampler(self) -> Union[Iterable, Sequence, Mapping]:
+     def batch_sampler(self) -> Any:
        """Return a collections of batch samplers extracted from loaders."""
        return apply_to_collection(self.loaders, (DataLoader, IterableDataset), getattr, "batch_sampler", None)

-     def _wrap_loaders_max_size_cycle(self) -> Any:
+     def _wrap_loaders_max_size_cycle(self) -> None:
        """Wraps all loaders to make sure they are cycled until the longest loader is exhausted.

        Returns:
            the wrapped loaders
        """
-         from lightning.pytorch.utilities.data import get_len
-
-         all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
-
-         length = _nested_calc_num_data(all_lengths, max)
-
-         # multiple loaders
-         if isinstance(self.loaders, (Sequence, Mapping)):
-             state = SharedCycleIteratorState()
-
-             self.loaders = apply_to_collection(
-                 self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
-             )
-             state.reset()
+         if self._mode != "max_size_cycle" or not isinstance(self.loaders, (Sequence, Mapping)):
+             return
+         length = self._calc_num_batches()
+         state = SharedCycleIteratorState()
+         self.loaders = apply_to_collection(
+             self.loaders, Iterable, CycleIterator, length=length, state=state, wrong_dtype=(Sequence, Mapping)
+         )
+         state.reset()

    def _apply_cycle_iterator_length(self) -> None:
        """When the model is `max_size_cycle`, compute the length across all ``CycleIterator`` and re-assign it to
        all dataloaders."""
-         from lightning.pytorch.utilities.data import get_len
-
-         if self.mode != "max_size_cycle":
+         if self._mode != "max_size_cycle":
            return

+         from lightning.pytorch.utilities.data import get_len
+
        def set_len(cycle_iterator: CycleIterator, length: int) -> None:
            cycle_iterator.length = length

        all_lengths = apply_to_collection(self.loaders, CycleIterator, lambda c: get_len(c.loader))
-         max_length = _nested_calc_num_data(all_lengths, max)
+         max_length = _reduce_data(all_lengths, max, float("-inf"))
        apply_to_collection(self.loaders, CycleIterator, set_len, length=max_length)

    def __iter__(self) -> Any:
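
The wrapping above is what makes `"max_size_cycle"` work at iteration time: shorter loaders are cycled until the longest one is exhausted. A rough, standalone analogy using `itertools` follows; this is not how `CycleIterator` is implemented, but the resulting batch pairing has the same shape.

import itertools

short = [10, 11]        # a loader with 2 batches
long = [20, 21, 22]     # a loader with 3 batches

# Cycle the shorter loader; stop when the longer one is exhausted (3 parallel steps).
steps = list(zip(itertools.cycle(short), long))
assert steps == [(10, 20), (11, 21), (10, 22)]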
@@ -323,27 +287,19 @@ def __getstate__patch__(*_: Any) -> Dict:
        self._iterator = iterator
        return iterator

-     @staticmethod
-     def _calc_num_batches(loaders: Any, mode: str = "min_size") -> Union[int, float]:
-         """Compute the length (aka the number of batches) of `CombinedLoader`.
-
-         Args:
-             loaders: a collections of loaders.
-             mode: Mode used by the CombinedDataloader
-
-         Returns:
-             length: the minimum length of loaders
-         """
+     def _calc_num_batches(self) -> _NUMBER:
        from lightning.pytorch.utilities.data import get_len

-         all_lengths = apply_to_collection(loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
-
-         if isinstance(all_lengths, (int, float)):
-             return all_lengths
-         return _nested_calc_num_data(all_lengths, max if mode == "max_size_cycle" else min)
-
-     def __len__(self) -> Union[int, float]:
-         return self._calc_num_batches(self.loaders, mode=self.mode)
+         all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))
+         mode = _supported_modes[self._mode]
+         return _reduce_data(all_lengths, mode["fn"], mode["default"])
+
+     def __len__(self) -> int:
+         """Compute the number of batches."""
+         length = self._calc_num_batches()
+         if isinstance(length, float):
+             raise TypeError(f"Number of batches must be an int, found {length}")
+         return length

    @staticmethod
    def _shutdown_workers_and_reset_iterator(dataloader: DataLoader) -> None:
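
As a quick sanity check of the numbers in the docstring example: with `drop_last=False` (the DataLoader default), batch counts round up, so the two loaders contribute 2 and 3 batches, which matches what the new `__len__` reports per mode.

import math

batches_a = math.ceil(6 / 4)    # DataLoader(range(6), batch_size=4)  -> 2 batches
batches_b = math.ceil(15 / 5)   # DataLoader(range(15), batch_size=5) -> 3 batches

assert min(batches_a, batches_b) == 2   # "min_size"       -> len(combined_loader) == 2
assert max(batches_a, batches_b) == 3   # "max_size_cycle" -> len(combined_loader) == 3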
@@ -417,25 +373,15 @@ def create_loader_iters(
        return apply_to_collection(loaders, Iterable, iter, wrong_dtype=(Sequence, Mapping))


- def _nested_calc_num_data(
-     data: Union[Mapping, Sequence], compute_func: Callable[[List[Union[int, float]]], Union[int, float]]
- ) -> Union[int, float]:
-
-     if isinstance(data, (float, int)):
-         return data
-
-     if isinstance(data, Mapping):
-         data = list(data.values())
-
-     if not isinstance(data, Sequence):
+ def _reduce_data(data: Any, pairwise_reduction: Callable[[_NUMBER, _NUMBER], _NUMBER], default: _NUMBER) -> _NUMBER:
+     if data is None:
        raise TypeError(f"Expected data to be int, Sequence or Mapping, but got {type(data).__name__}")

-     new_data = []
+     total = default

-     for x in data:
-         if isinstance(x, (Mapping, Sequence)):
-             new_data.append(_nested_calc_num_data(x, compute_func))
-         else:
-             new_data.append(x)
+     def reduce(v: _NUMBER) -> None:
+         nonlocal total
+         total = pairwise_reduction(total, v)

-     return compute_func(new_data)
+     apply_to_collection(data, (int, float), reduce)
+     return total
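
To see how the new `_reduce_data` helper behaves, here is a hedged, self-contained sketch that mirrors its pattern: `apply_to_collection` visits every `int`/`float` leaf of a nested collection while a closure folds them with the pairwise reduction. The helper name `reduce_data` below is illustrative, not part of the diff.

from typing import Any, Callable, Union

from lightning_utilities.core.apply_func import apply_to_collection

_NUMBER = Union[int, float]


def reduce_data(data: Any, pairwise_reduction: Callable[[_NUMBER, _NUMBER], _NUMBER], default: _NUMBER) -> _NUMBER:
    total = default

    def visit(v: _NUMBER) -> None:
        nonlocal total
        total = pairwise_reduction(total, v)

    # apply_to_collection recurses through Mapping/Sequence containers and calls `visit` on numeric leaves.
    apply_to_collection(data, (int, float), visit)
    return total


assert reduce_data({"a": 2, "b": [3, 5]}, max, float("-inf")) == 5
assert reduce_data({"a": 2, "b": [3, 5]}, min, float("inf")) == 2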