Commit 0ac1f86

Remove support for logging multiple metrics together (#16389)
1 parent 3a274fe commit 0ac1f86

File tree

11 files changed (+48, -180 lines)


docs/source-pytorch/extensions/logging.rst

Lines changed: 0 additions & 5 deletions

@@ -115,11 +115,6 @@ methods to log from anywhere in a :doc:`LightningModule <../common/lightning_mod
         self.log("my_metric", x)


-    # or a dict to get multiple metrics on the same plot if the logger supports it
-    def training_step(self, batch, batch_idx):
-        self.log("performance", {"acc": acc, "recall": recall})
-
-
     # or a dict to log all metrics at once with individual plots
     def training_step(self, batch, batch_idx):
         self.log_dict({"acc": acc, "recall": recall})
docs/source-pytorch/visualize/logging_intermediate.rst

Lines changed: 0 additions & 11 deletions

@@ -35,17 +35,6 @@ then access the logger's API directly

 ----

-****************************************
-Track multiple metrics in the same chart
-****************************************
-If your logger supports plotting multiple metrics on the same chart, pass in a dictionary to *self.log*.
-
-.. code-block:: python
-
-    self.log("performance", {"acc": acc, "recall": recall})
-
-----
-
 *********************
 Track hyperparameters
 *********************
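
This page loses its "same chart" section entirely. If grouped plots are still wanted, the remaining route is the logger's native API, which the same page already covers ("then access the logger's API directly"). A hedged sketch assuming a `TensorBoardLogger`, whose `experiment` attribute is the underlying `SummaryWriter` and exposes `add_scalars` for grouped curves; `acc`, `recall` and the `_shared_step` helper are assumptions, as in the sketch above:

    def training_step(self, batch, batch_idx):
        loss, acc, recall = self._shared_step(batch)  # hypothetical helper computing the metrics
        # group both curves under one "performance" tag via the logger itself
        self.logger.experiment.add_scalars(
            "performance", {"acc": acc, "recall": recall}, global_step=self.global_step
        )
        return loss
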

src/pytorch_lightning/CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -98,6 +98,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Removed the automatic addition of a moving average of the `training_step` loss in the progress bar. Use `self.log("loss", ..., prog_bar=True)` instead. ([#16192](https://github.com/Lightning-AI/lightning/issues/16192))

+- Removed support for passing a dictionary value to `self.log()` ([#16389](https://github.com/Lightning-AI/lightning/pull/16389))
+
 - Tuner removal
   * Removed the deprecated `trainer.tuning` property ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))
   * Removed the deprecated `TrainerFn.TUNING` and `RunningStage.TUNING` enums ([#16379](https://github.com/Lightning-AI/lightning/pull/16379))
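
The moving-average entry above already names its replacement; a one-line sketch of it inside the step (the loss computation is a placeholder):

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # placeholder for the actual loss computation
        self.log("loss", loss, prog_bar=True)
        return loss
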

src/pytorch_lightning/core/module.py

Lines changed: 6 additions & 7 deletions

@@ -53,7 +53,7 @@
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
 from pytorch_lightning.utilities.types import (
-    _METRIC_COLLECTION,
+    _METRIC,
     EPOCH_OUTPUT,
     LRSchedulerPLType,
     LRSchedulerTypeUnion,
@@ -337,7 +337,7 @@ def forward(self, x):
     def log(
         self,
         name: str,
-        value: _METRIC_COLLECTION,
+        value: _METRIC,
         prog_bar: bool = False,
         logger: Optional[bool] = None,
         on_step: Optional[bool] = None,
@@ -361,7 +361,7 @@ def log(

         Args:
             name: key to log.
-            value: value to log. Can be a ``float``, ``Tensor``, ``Metric``, or a dictionary of the former.
+            value: value to log. Can be a ``float``, ``Tensor``, or a ``Metric``.
             prog_bar: if ``True`` logs to the progress bar.
             logger: if ``True`` logs to the logger.
             on_step: if ``True`` logs at this step. The default value is determined by the hook.
@@ -390,7 +390,7 @@ def log(
         # check for invalid values
         apply_to_collection(value, dict, self.__check_not_nested, name)
         apply_to_collection(
-            value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor, dict)
+            value, object, self.__check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor)
         )

         if self._trainer is None:
@@ -492,7 +492,7 @@ def log(

     def log_dict(
         self,
-        dictionary: Mapping[str, _METRIC_COLLECTION],
+        dictionary: Mapping[str, _METRIC],
         prog_bar: bool = False,
         logger: Optional[bool] = None,
         on_step: Optional[bool] = None,
@@ -514,8 +514,7 @@ def log_dict(

         Args:
             dictionary: key value pairs.
-                The values can be a ``float``, ``Tensor``, ``Metric``, a dictionary of the former
-                or a ``MetricCollection``.
+                The values can be a ``float``, ``Tensor``, ``Metric``, or ``MetricCollection``.
             prog_bar: if ``True`` logs to the progress base.
             logger: if ``True`` logs to the logger.
             on_step: if ``True`` logs at this step.
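
The `wrong_dtype` change in the `@@ -390,7 +390,7 @@` hunk is what turns a dictionary value into an error: `apply_to_collection` hands anything that is not one of the excluded leaf types to `__check_allowed`, which raises, and `dict` is no longer excluded. A standalone sketch of that gating under stated assumptions: `check_allowed` and `validate` are stand-ins for the private methods, and the error message is illustrative, not the library's exact text.

    import numbers
    from typing import Any

    from lightning_utilities.core.apply_func import apply_to_collection
    from torch import Tensor
    from torchmetrics import Metric


    def check_allowed(v: Any, name: str, value: Any) -> None:
        # stand-in for LightningModule.__check_allowed: anything not excluded by wrong_dtype lands here
        raise ValueError(f"`self.log({name!r}, {value!r})` was called with an unsupported value type {type(v).__name__}")


    def validate(name: str, value: Any) -> None:
        # `dict` was dropped from wrong_dtype in this commit, so dictionary values now reach check_allowed
        apply_to_collection(value, object, check_allowed, name, value, wrong_dtype=(numbers.Number, Metric, Tensor))


    validate("acc", 0.9)  # fine: numbers, tensors and metrics are the allowed leaf types
    try:
        validate("performance", {"acc": 0.9})  # now rejected
    except ValueError as err:
        print(err)
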

src/pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 22 additions & 110 deletions

@@ -16,7 +16,7 @@
 from typing import Any, Callable, cast, Dict, Generator, List, Optional, Tuple, Union

 import torch
-from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
+from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torchmetrics import Metric
 from typing_extensions import TypedDict
@@ -317,7 +317,6 @@ def __getstate__(self, drop_value: bool = False) -> dict:
             skip.append("value")
         d = {k: v for k, v in self.__dict__.items() if k not in skip}
         d["meta"] = d["meta"].__getstate__()
-        d["_class"] = self.__class__.__name__
         d["_is_synced"] = False  # don't consider the state as synced on reload
         return d
@@ -338,48 +337,9 @@ def to(self, *args: Any, **kwargs: Any) -> "_ResultMetric":
         return self


-class _ResultMetricCollection(dict):
-    """Dict wrapper for easy access to metadata.
-
-    All of the leaf items should be instances of
-    :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric`
-    with the same metadata.
-    """
-
-    @property
-    def meta(self) -> _Metadata:
-        return next(iter(self.values())).meta
-
-    @property
-    def has_tensor(self) -> bool:
-        return any(v.is_tensor for v in self.values())
-
-    def __getstate__(self, drop_value: bool = False) -> dict:
-        def getstate(item: _ResultMetric) -> dict:
-            return item.__getstate__(drop_value=drop_value)
-
-        items = apply_to_collection(dict(self), _ResultMetric, getstate)
-        return {"items": items, "meta": self.meta.__getstate__(), "_class": self.__class__.__name__}
-
-    def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
-        # can't use `apply_to_collection` as it does not recurse items of the same type
-        items = {k: _ResultMetric._reconstruct(v, sync_fn=sync_fn) for k, v in state["items"].items()}
-        self.update(items)
-
-    @classmethod
-    def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "_ResultMetricCollection":
-        rmc = cls()
-        rmc.__setstate__(state, sync_fn=sync_fn)
-        return rmc
-
-
-_METRIC_COLLECTION = Union[_IN_METRIC, _ResultMetricCollection]
-
-
 class _ResultCollection(dict):
-    """
-    Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric` or
-    :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetricCollection`
+    """Collection (dictionary) of
+    :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric`

     Example:

@@ -404,18 +364,9 @@ def __init__(self, training: bool, device: Optional[Union[str, torch.device]] =

     @property
     def result_metrics(self) -> List[_ResultMetric]:
-        o = []
-
-        def append_fn(v: _ResultMetric) -> None:
-            nonlocal o
-            o.append(v)
-
-        apply_to_collection(list(self.values()), _ResultMetric, append_fn)
-        return o
+        return list(self.values())

-    def _extract_batch_size(
-        self, value: Union[_ResultMetric, _ResultMetricCollection], batch_size: Optional[int], meta: _Metadata
-    ) -> int:
+    def _extract_batch_size(self, value: _ResultMetric, batch_size: Optional[int], meta: _Metadata) -> int:
         # check if we have extracted the batch size already
         if batch_size is None:
             batch_size = self.batch_size
@@ -424,8 +375,7 @@ def _extract_batch_size(
             return batch_size

         batch_size = 1
-        is_tensor = value.is_tensor if isinstance(value, _ResultMetric) else value.has_tensor
-        if self.batch is not None and is_tensor and meta.on_epoch and meta.is_mean_reduction:
+        if self.batch is not None and value.is_tensor and meta.on_epoch and meta.is_mean_reduction:
             batch_size = extract_batch_size(self.batch)
             self.batch_size = batch_size

@@ -435,7 +385,7 @@ def log(
         self,
         fx: str,
         name: str,
-        value: _METRIC_COLLECTION,
+        value: _IN_METRIC,
         prog_bar: bool = False,
         logger: bool = True,
         on_step: bool = False,
@@ -494,28 +444,19 @@ def log(
             batch_size = self._extract_batch_size(self[key], batch_size, meta)
         self.update_metrics(key, value, batch_size)

-    def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
+    def register_key(self, key: str, meta: _Metadata, value: _IN_METRIC) -> None:
         """Create one _ResultMetric object per value.

         Value can be provided as a nested collection
         """
+        metric = _ResultMetric(meta, isinstance(value, Tensor)).to(self.device)
+        self[key] = metric

-        def fn(v: _IN_METRIC) -> _ResultMetric:
-            metric = _ResultMetric(meta, isinstance(v, Tensor))
-            return metric.to(self.device)
-
-        value = apply_to_collection(value, (Tensor, Metric), fn)
-        if isinstance(value, dict):
-            value = _ResultMetricCollection(value)
-        self[key] = value
-
-    def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: int) -> None:
-        def fn(result_metric: _ResultMetric, v: Tensor) -> None:
-            # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
-            result_metric.forward(v.to(self.device), batch_size)
-            result_metric.has_reset = False
-
-        apply_to_collections(self[key], value, _ResultMetric, fn)
+    def update_metrics(self, key: str, value: _IN_METRIC, batch_size: int) -> None:
+        result_metric = self[key]
+        # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
+        result_metric.forward(value.to(self.device), batch_size)
+        result_metric.has_reset = False

     @staticmethod
     def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
@@ -557,11 +498,7 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:

     def valid_items(self) -> Generator:
         """This function is used to iterate over current valid metrics."""
-        return (
-            (k, v)
-            for k, v in self.items()
-            if not (isinstance(v, _ResultMetric) and v.has_reset) and self.dataloader_idx == v.meta.dataloader_idx
-        )
+        return ((k, v) for k, v in self.items() if not v.has_reset and self.dataloader_idx == v.meta.dataloader_idx)

     def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> Tuple[str, str]:
         name = result_metric.meta.name
@@ -578,23 +515,9 @@ def metrics(self, on_step: bool) -> _METRICS:
         metrics = _METRICS(callback={}, log={}, pbar={})

         for _, result_metric in self.valid_items():
-
-            # extract forward_cache or computed from the _ResultMetric. ignore when the output is None
-            value = apply_to_collection(result_metric, _ResultMetric, self._get_cache, on_step, include_none=False)
-
-            # convert metric collection to dict container.
-            if isinstance(value, _ResultMetricCollection):
-                value = dict(value.items())
-
-            # check if the collection is empty
-            has_tensor = False
-
-            def any_tensor(_: Any) -> None:
-                nonlocal has_tensor
-                has_tensor = True
-
-            apply_to_collection(value, Tensor, any_tensor)
-            if not has_tensor:
+            # extract forward_cache or computed from the _ResultMetric
+            value = self._get_cache(result_metric, on_step)
+            if not isinstance(value, Tensor):
                 continue

             name, forked_name = self._forked_name(result_metric, on_step)
@@ -623,15 +546,12 @@ def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> Non
             if ``None``, both are.
             fx: Function to reset
         """
-
-        def fn(item: _ResultMetric) -> None:
+        for item in self.values():
             requested_type = metrics is None or metrics ^ item.is_tensor
             same_fx = fx is None or fx == item.meta.fx
             if requested_type and same_fx:
                 item.reset()

-        apply_to_collection(self, _ResultMetric, fn)
-
     def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection":
         """Move all data to the given device."""
         self.update(apply_to_collection(dict(self), (Tensor, Metric), move_data_to_device, *args, **kwargs))
@@ -664,7 +584,6 @@ def __repr__(self) -> str:

     def __getstate__(self, drop_value: bool = True) -> dict:
         d = self.__dict__.copy()
-        # all the items should be either `_ResultMetric`s or `_ResultMetricCollection`s
         items = {k: v.__getstate__(drop_value=drop_value) for k, v in self.items()}
         return {**d, "items": items}
@@ -673,18 +592,11 @@ def __setstate__(
     ) -> None:
         self.__dict__.update({k: v for k, v in state.items() if k != "items"})

-        def setstate(k: str, item: dict) -> Union[_ResultMetric, _ResultMetricCollection]:
+        def setstate(k: str, item: dict) -> _ResultMetric:
             if not isinstance(item, dict):
                 raise ValueError(f"Unexpected value: {item}")
-            cls = item["_class"]
-            if cls == _ResultMetric.__name__:
-                cls = _ResultMetric
-            elif cls == _ResultMetricCollection.__name__:
-                cls = _ResultMetricCollection
-            else:
-                raise ValueError(f"Unexpected class name: {cls}")
             _sync_fn = sync_fn or (self[k].meta.sync.fn if k in self else None)
-            return cls._reconstruct(item, sync_fn=_sync_fn)
+            return _ResultMetric._reconstruct(item, sync_fn=_sync_fn)

         items = {k: setstate(k, v) for k, v in state["items"].items()}
         self.update(items)
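
A pattern repeats across the hunks above: once a key can only map to a single `_ResultMetric`, the callback-style `apply_to_collection` traversals collapse into plain dict loops. A standalone sketch of the `reset` refactor, using a stand-in `Item` class rather than the real `_ResultMetric` (its `fx` attribute flattens `meta.fx`):

    from dataclasses import dataclass
    from typing import Dict, Optional


    @dataclass
    class Item:
        """Stand-in for _ResultMetric with just the attributes the reset logic touches."""

        is_tensor: bool
        fx: str
        was_reset: bool = False

        def reset(self) -> None:
            self.was_reset = True


    def reset(items: Dict[str, Item], metrics: Optional[bool] = None, fx: Optional[str] = None) -> None:
        # the collection is flat now, so a plain loop replaces apply_to_collection(self, _ResultMetric, fn)
        for item in items.values():
            requested_type = metrics is None or metrics ^ item.is_tensor
            same_fx = fx is None or fx == item.fx
            if requested_type and same_fx:
                item.reset()


    items = {"training_step.loss": Item(is_tensor=True, fx="training_step")}
    reset(items, fx="training_step")  # metrics=None resets both tensor- and Metric-backed entries
    assert items["training_step.loss"].was_reset
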

src/pytorch_lightning/utilities/types.py

Lines changed: 1 addition & 2 deletions

@@ -19,7 +19,7 @@
 from argparse import _ArgumentGroup, ArgumentParser
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import Any, Dict, Generator, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Dict, Generator, List, Optional, Sequence, Type, Union

 import torch
 from torch import Tensor
@@ -31,7 +31,6 @@

 _NUMBER = Union[int, float]
 _METRIC = Union[Metric, Tensor, _NUMBER]
-_METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]
 STEP_OUTPUT = Union[Tensor, Dict[str, Any]]
 EPOCH_OUTPUT = List[STEP_OUTPUT]
 _EVALUATE_OUTPUT = List[Dict[str, float]]  # 1 dict per DataLoader

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

Lines changed: 5 additions & 5 deletions

@@ -407,8 +407,8 @@ def test_tensor_to_float_conversion(tmpdir):
     class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
             self.log("a", torch.tensor(0.123), prog_bar=True, on_epoch=False)
-            self.log("b", {"b1": torch.tensor([1])}, prog_bar=True, on_epoch=False)
-            self.log("c", {"c1": 2}, prog_bar=True, on_epoch=False)
+            self.log("b", torch.tensor([1]), prog_bar=True, on_epoch=False)
+            self.log("c", 2, prog_bar=True, on_epoch=False)
             return super().training_step(batch, batch_idx)

     trainer = Trainer(
@@ -417,11 +417,11 @@ def training_step(self, batch, batch_idx):
     trainer.fit(TestModel())

     torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123)
-    assert trainer.progress_bar_metrics["b"] == {"b1": 1.0}
-    assert trainer.progress_bar_metrics["c"] == {"c1": 2.0}
+    assert trainer.progress_bar_metrics["b"] == 1.0
+    assert trainer.progress_bar_metrics["c"] == 2.0
     pbar = trainer.progress_bar_callback.main_progress_bar
     actual = str(pbar.postfix)
-    assert actual.endswith("a=0.123, b={'b1': 1.0}, c={'c1': 2.0}"), actual
+    assert actual.endswith("a=0.123, b=1.000, c=2.000"), actual


 @pytest.mark.parametrize(

tests/tests_pytorch/core/test_metric_result_integration.py

Lines changed: 2 additions & 3 deletions

@@ -237,11 +237,10 @@ def lightning_log(fx, *args, **kwargs):
         lightning_log("training_step", "c", metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c")
         lightning_log("training_step", "a_1", a, on_step=True, on_epoch=True)
         lightning_log("training_step", "b_1", b, on_step=False, on_epoch=True)
-        lightning_log("training_step", "c_1", {"1": c, "2": c}, on_step=True, on_epoch=False)
+        lightning_log("training_step", "c_1", c, on_step=True, on_epoch=False)

         batch_log = result.metrics(on_step=True)["log"]
         assert set(batch_log) == {"a_step", "c", "a_1_step", "c_1"}
-        assert set(batch_log["c_1"]) == {"1", "2"}

         result_copy = deepcopy(result)
         new_result = _ResultCollection(True, torch.device("cpu"))
@@ -250,7 +249,7 @@ def lightning_log(fx, *args, **kwargs):
         assert "fn" not in state_dict["items"]["training_step.a"]["meta"]["_sync"]

         assert not new_result.result_metrics
-        assert len(result.result_metrics) == 7 + epoch > 0
+        assert len(result.result_metrics) == 6 + epoch > 0

         new_result.load_state_dict(
             state_dict, metrics={"metric": metric, "metric_b": metric_b, "metric_c": metric_c}
