
Commit 1ec4e02

awaelchli and carmocca committed
Remove special handling of loss in progress bar (#16192)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent a141d04 commit 1ec4e02

File tree: 12 files changed (+9, -192 lines)

src/pytorch_lightning/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -65,6 +65,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Removed the `LoggerConnector.on_train_split_start` method
 
 - Removed the `LightningModule.precision` attribute ([#16203](https://github.com/Lightning-AI/lightning/pull/16203))
+- Removed the automatic addition of a moving average of the `training_step` loss in the progress bar. Use `self.log("loss", ..., prog_bar=True)` instead. ([#16192](https://github.com/Lightning-AI/lightning/issues/16192))
 
 
 ### Fixed
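
For users migrating past this change, a minimal sketch of the replacement the changelog suggests: log the training loss explicitly with `prog_bar=True` so it keeps appearing in the progress bar. The module below is illustrative only (a hypothetical `LitModel`, not code from this commit):

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    """Hypothetical module showing the recommended migration."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # The smoothed "loss" entry is no longer added automatically;
        # log it explicitly to keep it visible in the progress bar.
        self.log("loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```

Note that the value shown this way is the raw per-step loss, not the 20-step moving average the progress bar used to display.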

src/pytorch_lightning/callbacks/progress/base.py

Lines changed: 5 additions & 15 deletions
@@ -242,7 +242,7 @@ def get_metrics(self, trainer, model):
         Return:
             Dictionary with the items to be displayed in the progress bar.
         """
-        standard_metrics = get_standard_metrics(trainer, pl_module)
+        standard_metrics = get_standard_metrics(trainer)
         pbar_metrics = trainer.progress_bar_metrics
         duplicates = list(standard_metrics.keys() & pbar_metrics.keys())
         if duplicates:
@@ -255,30 +255,20 @@ def get_metrics(self, trainer, model):
         return {**standard_metrics, **pbar_metrics}
 
 
-def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Dict[str, Union[int, str]]:
+def get_standard_metrics(trainer: "pl.Trainer") -> Dict[str, Union[int, str]]:
     r"""
-    Returns several standard metrics displayed in the progress bar, including the average loss value,
-    split index of BPTT (if used) and the version of the experiment when using a logger.
+    Returns the standard metrics displayed in the progress bar.
+    Currently, it only includes the version of the experiment when using a logger.
 
     .. code-block::
 
-        Epoch 1:   4%|▎         | 40/1095 [00:03<01:37, 10.84it/s, loss=4.501, v_num=10]
+        Epoch 1:   4%|▎         | 40/1095 [00:03<01:37, 10.84it/s, v_num=10]
 
     Return:
         Dictionary with the standard metrics to be displayed in the progress bar.
     """
-    # call .item() only once but store elements without graphs
-    running_train_loss = trainer.fit_loop.running_loss.mean()
-    avg_training_loss = None
-    if running_train_loss is not None:
-        avg_training_loss = running_train_loss.cpu().item()
-    elif pl_module.automatic_optimization:
-        avg_training_loss = float("NaN")
 
     items_dict: Dict[str, Union[int, str]] = {}
-    if avg_training_loss is not None:
-        items_dict["loss"] = f"{avg_training_loss:.3g}"
-
     if trainer.loggers:
         version = _version(trainer.loggers)
         if version is not None:
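
Because `get_standard_metrics` no longer injects a `loss` entry, a custom progress bar can re-add one by overriding the `get_metrics` hook shown above. A hedged sketch; `train_loss` is a hypothetical key that the LightningModule is assumed to log:

```python
from pytorch_lightning.callbacks import TQDMProgressBar


class LossAwareProgressBar(TQDMProgressBar):
    """Illustrative progress bar that re-adds a loss entry."""

    def get_metrics(self, trainer, pl_module):
        items = super().get_metrics(trainer, pl_module)
        # `trainer.callback_metrics` holds the latest logged values as tensors.
        loss = trainer.callback_metrics.get("train_loss")
        if loss is not None:
            items["loss"] = f"{float(loss):.3g}"
        return items
```

Pass it to the trainer as usual, e.g. `Trainer(callbacks=[LossAwareProgressBar()])`.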

src/pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 1 addition & 24 deletions
@@ -18,7 +18,6 @@
 import numpy as np
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
-from torch import Tensor
 
 import pytorch_lightning as pl
 from pytorch_lightning import loops  # import as loops to avoid circular imports
@@ -28,7 +27,7 @@
 from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
 from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
-from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
@@ -60,8 +59,6 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None
         self.batch_progress = BatchProgress()
         self.scheduler_progress = SchedulerProgress()
 
-        self.accumulated_loss = TensorRunningAccum(window_length=20)
-        self.running_loss = TensorRunningAccum(window_length=20)
         self.optimizer_loop = OptimizerLoop()
         self.manual_loop = ManualOptimization()
 
@@ -294,11 +291,6 @@ def teardown(self) -> None:
         self._results.cpu()
         self.optimizer_loop.teardown()
         self.manual_loop.teardown()
-        # release memory
-        if self.accumulated_loss.memory is not None:
-            self.accumulated_loss.memory = self.accumulated_loss.memory.cpu()
-        if self.running_loss.memory is not None:
-            self.running_loss.memory = self.running_loss.memory.cpu()
         self.val_loop.teardown()
 
     def on_save_checkpoint(self) -> Dict:
@@ -554,21 +546,6 @@ def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> Orde
         kwargs["batch_idx"] = batch_idx
         return kwargs
 
-    def _update_running_loss(self, current_loss: Tensor) -> None:
-        """Updates the running loss value with the current value."""
-        if self.trainer.lightning_module.automatic_optimization:
-            # track total loss for logging (avoid mem leaks)
-            self.accumulated_loss.append(current_loss)
-
-            accumulated_loss = self.accumulated_loss.mean()
-
-            if accumulated_loss is not None:
-                # calculate running loss for display
-                self.running_loss.append(self.accumulated_loss.mean() * self.trainer.accumulate_grad_batches)
-
-            # reset for next set of accumulated grads
-            self.accumulated_loss.reset()
-
 
 def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Optional[Dict[str, Any]]]:
     """Converts an optimizer dict to a list in which the key of the dict determines the position of the element.

src/pytorch_lightning/loops/fit_loop.py

Lines changed: 1 addition & 9 deletions
@@ -23,7 +23,7 @@
 from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
 from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
 from pytorch_lightning.trainer.progress import Progress
-from pytorch_lightning.trainer.supporters import CombinedLoader, TensorRunningAccum
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.fetching import (
     AbstractDataFetcher,
@@ -104,11 +104,6 @@ def max_steps(self, value: int) -> None:
             )
         self.epoch_loop.max_steps = value
 
-    @property
-    def running_loss(self) -> TensorRunningAccum:
-        """Returns the running loss."""
-        return self.epoch_loop.running_loss
-
     @Loop.restarting.setter
     def restarting(self, restarting: bool) -> None:
         # if the last epoch completely finished, we are not actually restarting
@@ -233,9 +228,6 @@ def on_advance_start(self) -> None:
         # changing gradient according accumulation_scheduler
         self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
 
-        # stores accumulated grad fractions per batch
-        self.epoch_loop.accumulated_loss.reset(window_length=self.trainer.accumulate_grad_batches)
-
         self.epoch_progress.increment_ready()
 
         self.trainer._logger_connector.on_epoch_start()
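
Code that previously read `trainer.fit_loop.running_loss` no longer has that property. One alternative, sketched here under the assumption that the module logs a metric named "loss", is to read the logged value back from the trainer:

```python
from typing import Optional


def read_last_loss(trainer) -> Optional[float]:
    """Assumes the LightningModule calls `self.log("loss", loss, prog_bar=True)`."""
    # `trainer.callback_metrics` holds the most recently logged values as tensors.
    loss = trainer.callback_metrics.get("loss")
    return None if loss is None else loss.item()
```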

src/pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 0 additions & 6 deletions
@@ -242,12 +242,6 @@ def _run_optimization(self, kwargs: OrderedDict, optimizer: torch.optim.Optimize
 
         result = closure.consume_result()
 
-        if result.loss is not None:
-            # if no result, user decided to skip optimization
-            # otherwise update running loss + reset accumulated loss
-            # TODO: find proper way to handle updating running loss
-            self.trainer.fit_loop.epoch_loop._update_running_loss(result.loss)
-
         # untoggle model params
         self._run_optimization_end(opt_idx)
         return result

src/pytorch_lightning/trainer/supporters.py

Lines changed: 0 additions & 78 deletions
@@ -18,7 +18,6 @@
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections
-from torch import Tensor
 from torch.utils.data import Dataset
 from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, DataLoader
 from torch.utils.data.dataset import IterableDataset
@@ -33,83 +32,6 @@
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 
 
-class TensorRunningAccum:
-    """Tracks a running accumulation values (min, max, mean) without graph references.
-
-    Examples:
-        >>> accum = TensorRunningAccum(5)
-        >>> accum.last(), accum.mean()
-        (None, None)
-        >>> accum.append(torch.tensor(1.5))
-        >>> accum.last(), accum.mean()
-        (tensor(1.5000), tensor(1.5000))
-        >>> accum.append(torch.tensor(2.5))
-        >>> accum.last(), accum.mean()
-        (tensor(2.5000), tensor(2.))
-        >>> accum.reset()
-        >>> _= [accum.append(torch.tensor(i)) for i in range(13)]
-        >>> accum.last(), accum.mean(), accum.min(), accum.max()
-        (tensor(12.), tensor(10.), tensor(8.), tensor(12.))
-    """
-
-    def __init__(self, window_length: int):
-        self.window_length = window_length
-        self.reset(window_length)
-
-    def reset(self, window_length: Optional[int] = None) -> None:
-        """Empty the accumulator."""
-        if window_length is not None:
-            self.window_length = window_length
-        self.memory: Optional[Tensor] = None
-        self.current_idx: int = 0
-        self.last_idx: Optional[int] = None
-        self.rotated: bool = False
-
-    def last(self) -> Optional[Tensor]:
-        """Get the last added element."""
-        if self.last_idx is not None:
-            assert isinstance(self.memory, Tensor)
-            return self.memory[self.last_idx].float()
-
-    def append(self, x: Tensor) -> None:
-        """Add an element to the accumulator."""
-        if self.memory is None:
-            # tradeoff memory for speed by keeping the memory on device
-            self.memory = torch.zeros(self.window_length, *x.shape, device=x.device, dtype=x.dtype)
-
-        # store without grads
-        with torch.no_grad():
-            self.memory[self.current_idx] = x
-            self.last_idx = self.current_idx
-
-        # increase index
-        self.current_idx += 1
-
-        # reset index when hit limit of tensor
-        self.current_idx = self.current_idx % self.window_length
-        if self.current_idx == 0:
-            self.rotated = True
-
-    def mean(self) -> Optional[Tensor]:
-        """Get mean value from stored elements."""
-        return self._agg_memory("mean")
-
-    def max(self) -> Optional[Tensor]:
-        """Get maximal value from stored elements."""
-        return self._agg_memory("max")
-
-    def min(self) -> Optional[Tensor]:
-        """Get minimal value from stored elements."""
-        return self._agg_memory("min")
-
-    def _agg_memory(self, how: str) -> Optional[Tensor]:
-        if self.last_idx is not None:
-            assert isinstance(self.memory, Tensor)
-            if self.rotated:
-                return getattr(self.memory.float(), how)()
-            return getattr(self.memory[: self.current_idx].float(), how)()
-
-
 @dataclass
 class SharedCycleIteratorState:
     """A state shared between all CycleIterators in a CombinedLoader.

src/pytorch_lightning/tuner/lr_finder.py

Lines changed: 1 addition & 1 deletion
@@ -379,7 +379,7 @@ def on_train_batch_end(
         if self.progress_bar:
             self.progress_bar.update()
 
-        loss_tensor = trainer.fit_loop.running_loss.last()
+        loss_tensor = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]
         assert loss_tensor is not None
         current_loss = loss_tensor.item()
         current_step = trainer.global_step
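
The lr finder now takes the loss straight from the `outputs` argument of `on_train_batch_end` rather than from the removed running-loss accumulator. The same pattern works in any callback; a sketch that assumes automatic optimization, where `outputs` is either the loss tensor or a dict containing it under "loss":

```python
import torch
from pytorch_lightning.callbacks import Callback


class LossPrinter(Callback):
    """Illustrative callback using the same pattern as the updated lr finder."""

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        loss_tensor = outputs if isinstance(outputs, torch.Tensor) else outputs["loss"]
        if batch_idx % 50 == 0:  # arbitrary print frequency for the example
            print(f"step {trainer.global_step}: loss={loss_tensor.item():.4f}")
```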

tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py

Lines changed: 0 additions & 1 deletion
@@ -353,7 +353,6 @@ class MockedProgressBar(RichProgressBar):
         def get_metrics(self, trainer, pl_module):
             items = super().get_metrics(trainer, model)
             del items["v_num"]
-            del items["loss"]
             # this is equivalent to mocking `set_postfix` as this method gets called every time
             self.calls[trainer.state.fn].append(
                 (trainer.state.stage, trainer.current_epoch, trainer.global_step, items)

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

Lines changed: 0 additions & 2 deletions
@@ -660,7 +660,6 @@ def get_metrics(self, trainer: Trainer, model: LightningModule):
     model = BoringModel()
     trainer.fit(model)
     standard_metrics = progress_bar.get_metrics(trainer, model)
-    assert "loss" in standard_metrics.keys()
     assert "v_num" not in standard_metrics.keys()
 
 
@@ -673,7 +672,6 @@ class MockedProgressBar(TQDMProgressBar):
         def get_metrics(self, trainer, pl_module):
             items = super().get_metrics(trainer, model)
             del items["v_num"]
-            del items["loss"]
             # this is equivalent to mocking `set_postfix` as this method gets called every time
             self.calls[trainer.state.fn].append(
                 (trainer.state.stage, trainer.current_epoch, trainer.global_step, items)

tests/tests_pytorch/trainer/logging_/test_logger_connector.py

Lines changed: 0 additions & 15 deletions
@@ -352,21 +352,6 @@ def test_epoch_end(self, outputs):
     trainer.test(model)
 
 
-def test_logging_to_progress_bar_with_reserved_key(tmpdir):
-    """Test that logging a metric with a reserved name to the progress bar raises a warning."""
-
-    class TestModel(BoringModel):
-        def training_step(self, *args, **kwargs):
-            output = super().training_step(*args, **kwargs)
-            self.log("loss", output["loss"], prog_bar=True)
-            return output
-
-    model = TestModel()
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
-    with pytest.warns(UserWarning, match="The progress bar already tracks a metric with the .* 'loss'"):
-        trainer.fit(model)
-
-
 @pytest.mark.parametrize("add_dataloader_idx", [False, True])
 def test_auto_add_dataloader_idx(tmpdir, add_dataloader_idx):
     """test that auto_add_dataloader_idx argument works."""
