16
16
from typing import Any , Callable , cast , Dict , Generator , List , Optional , Tuple , Union
17
17
18
18
import torch
19
- from lightning_utilities .core .apply_func import apply_to_collection , apply_to_collections
19
+ from lightning_utilities .core .apply_func import apply_to_collection
20
20
from torch import Tensor
21
21
from torchmetrics import Metric
22
22
from typing_extensions import TypedDict
@@ -317,7 +317,6 @@ def __getstate__(self, drop_value: bool = False) -> dict:
317
317
skip .append ("value" )
318
318
d = {k : v for k , v in self .__dict__ .items () if k not in skip }
319
319
d ["meta" ] = d ["meta" ].__getstate__ ()
320
- d ["_class" ] = self .__class__ .__name__
321
320
d ["_is_synced" ] = False # don't consider the state as synced on reload
322
321
return d
323
322
@@ -338,48 +337,9 @@ def to(self, *args: Any, **kwargs: Any) -> "_ResultMetric":
338
337
return self
339
338
340
339
341
- class _ResultMetricCollection (dict ):
342
- """Dict wrapper for easy access to metadata.
343
-
344
- All of the leaf items should be instances of
345
- :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric`
346
- with the same metadata.
347
- """
348
-
349
- @property
350
- def meta (self ) -> _Metadata :
351
- return next (iter (self .values ())).meta
352
-
353
- @property
354
- def has_tensor (self ) -> bool :
355
- return any (v .is_tensor for v in self .values ())
356
-
357
- def __getstate__ (self , drop_value : bool = False ) -> dict :
358
- def getstate (item : _ResultMetric ) -> dict :
359
- return item .__getstate__ (drop_value = drop_value )
360
-
361
- items = apply_to_collection (dict (self ), _ResultMetric , getstate )
362
- return {"items" : items , "meta" : self .meta .__getstate__ (), "_class" : self .__class__ .__name__ }
363
-
364
- def __setstate__ (self , state : dict , sync_fn : Optional [Callable ] = None ) -> None :
365
- # can't use `apply_to_collection` as it does not recurse items of the same type
366
- items = {k : _ResultMetric ._reconstruct (v , sync_fn = sync_fn ) for k , v in state ["items" ].items ()}
367
- self .update (items )
368
-
369
- @classmethod
370
- def _reconstruct (cls , state : dict , sync_fn : Optional [Callable ] = None ) -> "_ResultMetricCollection" :
371
- rmc = cls ()
372
- rmc .__setstate__ (state , sync_fn = sync_fn )
373
- return rmc
374
-
375
-
376
- _METRIC_COLLECTION = Union [_IN_METRIC , _ResultMetricCollection ]
377
-
378
-
379
340
class _ResultCollection (dict ):
380
- """
381
- Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric` or
382
- :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetricCollection`
341
+ """Collection (dictionary) of
342
+ :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric`
383
343
384
344
Example:
385
345
@@ -404,18 +364,9 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] =
404
364
405
365
@property
406
366
def result_metrics (self ) -> List [_ResultMetric ]:
407
- o = []
408
-
409
- def append_fn (v : _ResultMetric ) -> None :
410
- nonlocal o
411
- o .append (v )
412
-
413
- apply_to_collection (list (self .values ()), _ResultMetric , append_fn )
414
- return o
367
+ return list (self .values ())
415
368
416
- def _extract_batch_size (
417
- self , value : Union [_ResultMetric , _ResultMetricCollection ], batch_size : Optional [int ], meta : _Metadata
418
- ) -> int :
369
+ def _extract_batch_size (self , value : _ResultMetric , batch_size : Optional [int ], meta : _Metadata ) -> int :
419
370
# check if we have extracted the batch size already
420
371
if batch_size is None :
421
372
batch_size = self .batch_size
@@ -424,8 +375,7 @@ def _extract_batch_size(
424
375
return batch_size
425
376
426
377
batch_size = 1
427
- is_tensor = value .is_tensor if isinstance (value , _ResultMetric ) else value .has_tensor
428
- if self .batch is not None and is_tensor and meta .on_epoch and meta .is_mean_reduction :
378
+ if self .batch is not None and value .is_tensor and meta .on_epoch and meta .is_mean_reduction :
429
379
batch_size = extract_batch_size (self .batch )
430
380
self .batch_size = batch_size
431
381
@@ -435,7 +385,7 @@ def log(
435
385
self ,
436
386
fx : str ,
437
387
name : str ,
438
- value : _METRIC_COLLECTION ,
388
+ value : _IN_METRIC ,
439
389
prog_bar : bool = False ,
440
390
logger : bool = True ,
441
391
on_step : bool = False ,
@@ -494,28 +444,19 @@ def log(
494
444
batch_size = self ._extract_batch_size (self [key ], batch_size , meta )
495
445
self .update_metrics (key , value , batch_size )
496
446
497
- def register_key (self , key : str , meta : _Metadata , value : _METRIC_COLLECTION ) -> None :
447
+ def register_key (self , key : str , meta : _Metadata , value : _IN_METRIC ) -> None :
498
448
"""Create one _ResultMetric object per value.
499
449
500
450
Value can be provided as a nested collection
501
451
"""
452
+ metric = _ResultMetric (meta , isinstance (value , Tensor )).to (self .device )
453
+ self [key ] = metric
502
454
503
- def fn (v : _IN_METRIC ) -> _ResultMetric :
504
- metric = _ResultMetric (meta , isinstance (v , Tensor ))
505
- return metric .to (self .device )
506
-
507
- value = apply_to_collection (value , (Tensor , Metric ), fn )
508
- if isinstance (value , dict ):
509
- value = _ResultMetricCollection (value )
510
- self [key ] = value
511
-
512
- def update_metrics (self , key : str , value : _METRIC_COLLECTION , batch_size : int ) -> None :
513
- def fn (result_metric : _ResultMetric , v : Tensor ) -> None :
514
- # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
515
- result_metric .forward (v .to (self .device ), batch_size )
516
- result_metric .has_reset = False
517
-
518
- apply_to_collections (self [key ], value , _ResultMetric , fn )
455
+ def update_metrics (self , key : str , value : _IN_METRIC , batch_size : int ) -> None :
456
+ result_metric = self [key ]
457
+ # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
458
+ result_metric .forward (value .to (self .device ), batch_size )
459
+ result_metric .has_reset = False
519
460
520
461
@staticmethod
521
462
def _get_cache (result_metric : _ResultMetric , on_step : bool ) -> Optional [Tensor ]:
@@ -557,11 +498,7 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
557
498
558
499
def valid_items (self ) -> Generator :
559
500
"""This function is used to iterate over current valid metrics."""
560
- return (
561
- (k , v )
562
- for k , v in self .items ()
563
- if not (isinstance (v , _ResultMetric ) and v .has_reset ) and self .dataloader_idx == v .meta .dataloader_idx
564
- )
501
+ return ((k , v ) for k , v in self .items () if not v .has_reset and self .dataloader_idx == v .meta .dataloader_idx )
565
502
566
503
def _forked_name (self , result_metric : _ResultMetric , on_step : bool ) -> Tuple [str , str ]:
567
504
name = result_metric .meta .name
@@ -578,23 +515,9 @@ def metrics(self, on_step: bool) -> _METRICS:
578
515
metrics = _METRICS (callback = {}, log = {}, pbar = {})
579
516
580
517
for _ , result_metric in self .valid_items ():
581
-
582
- # extract forward_cache or computed from the _ResultMetric. ignore when the output is None
583
- value = apply_to_collection (result_metric , _ResultMetric , self ._get_cache , on_step , include_none = False )
584
-
585
- # convert metric collection to dict container.
586
- if isinstance (value , _ResultMetricCollection ):
587
- value = dict (value .items ())
588
-
589
- # check if the collection is empty
590
- has_tensor = False
591
-
592
- def any_tensor (_ : Any ) -> None :
593
- nonlocal has_tensor
594
- has_tensor = True
595
-
596
- apply_to_collection (value , Tensor , any_tensor )
597
- if not has_tensor :
518
+ # extract forward_cache or computed from the _ResultMetric
519
+ value = self ._get_cache (result_metric , on_step )
520
+ if not isinstance (value , Tensor ):
598
521
continue
599
522
600
523
name , forked_name = self ._forked_name (result_metric , on_step )
@@ -623,15 +546,12 @@ def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> Non
623
546
if ``None``, both are.
624
547
fx: Function to reset
625
548
"""
626
-
627
- def fn (item : _ResultMetric ) -> None :
549
+ for item in self .values ():
628
550
requested_type = metrics is None or metrics ^ item .is_tensor
629
551
same_fx = fx is None or fx == item .meta .fx
630
552
if requested_type and same_fx :
631
553
item .reset ()
632
554
633
- apply_to_collection (self , _ResultMetric , fn )
634
-
635
555
def to (self , * args : Any , ** kwargs : Any ) -> "_ResultCollection" :
636
556
"""Move all data to the given device."""
637
557
self .update (apply_to_collection (dict (self ), (Tensor , Metric ), move_data_to_device , * args , ** kwargs ))
@@ -664,7 +584,6 @@ def __repr__(self) -> str:
664
584
665
585
def __getstate__ (self , drop_value : bool = True ) -> dict :
666
586
d = self .__dict__ .copy ()
667
- # all the items should be either `_ResultMetric`s or `_ResultMetricCollection`s
668
587
items = {k : v .__getstate__ (drop_value = drop_value ) for k , v in self .items ()}
669
588
return {** d , "items" : items }
670
589
@@ -673,18 +592,11 @@ def __setstate__(
673
592
) -> None :
674
593
self .__dict__ .update ({k : v for k , v in state .items () if k != "items" })
675
594
676
- def setstate (k : str , item : dict ) -> Union [ _ResultMetric , _ResultMetricCollection ] :
595
+ def setstate (k : str , item : dict ) -> _ResultMetric :
677
596
if not isinstance (item , dict ):
678
597
raise ValueError (f"Unexpected value: { item } " )
679
- cls = item ["_class" ]
680
- if cls == _ResultMetric .__name__ :
681
- cls = _ResultMetric
682
- elif cls == _ResultMetricCollection .__name__ :
683
- cls = _ResultMetricCollection
684
- else :
685
- raise ValueError (f"Unexpected class name: { cls } " )
686
598
_sync_fn = sync_fn or (self [k ].meta .sync .fn if k in self else None )
687
- return cls ._reconstruct (item , sync_fn = _sync_fn )
599
+ return _ResultMetric ._reconstruct (item , sync_fn = _sync_fn )
688
600
689
601
items = {k : setstate (k , v ) for k , v in state ["items" ].items ()}
690
602
self .update (items )
0 commit comments