docs/source-pytorch/common/progress_bar.rst (1 addition & 1 deletion)
@@ -22,7 +22,7 @@ The :class:`~pytorch_lightning.callbacks.TQDMProgressBar` uses the `tqdm <https:
It prints to ``stdout`` and shows up to four different bars:

- **sanity check progress:** the progress during the sanity check run
-- **main progress:** shows training + validation progress combined. It also accounts for multiple validation runs during training when :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used.
+- **train progress:** shows the training progress. It will pause if validation starts and will resume when it ends, and also accounts for multiple validation runs during training when :paramref:`~pytorch_lightning.trainer.trainer.Trainer.val_check_interval` is used.
- **validation progress:** only visible during validation; shows total progress over all validation datasets.
- **test progress:** only active when testing; shows total progress over all test datasets.
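A minimal sketch of attaching this bar to a ``Trainer``, assuming only the public ``pytorch_lightning`` API documented here; the argument values are illustrative:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import TQDMProgressBar

# Redraw every 10 batches instead of every batch; process_position shifts the
# bar down so several trainers can share one terminal.
bar = TQDMProgressBar(refresh_rate=10, process_position=0)
trainer = pl.Trainer(callbacks=[bar])
```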

src/lightning/pytorch/CHANGELOG.md (4 additions & 0 deletions)
@@ -58,6 +58,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- The strategy selected by `accelerator="hpu"` now defaults to `find_unused_parameters=False` ([#16611](https://github.com/Lightning-AI/lightning/pull/16611))

+- The main progress bar displayed during training no longer includes the combined progress for validation ([#16695](https://github.com/Lightning-AI/lightning/pull/16695))
+
+- Renamed `TQDMProgressBar.main_progress_bar` to `TQDMProgressBar.train_progress_bar` ([#16695](https://github.com/Lightning-AI/lightning/pull/16695))
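For downstream code that touched the old attribute, the rename is mechanical. A sketch of the before/after, using a hypothetical subclass (`LoggingBar` is not part of this PR):

```python
from lightning.pytorch.callbacks import TQDMProgressBar

class LoggingBar(TQDMProgressBar):  # hypothetical subclass for illustration
    def on_train_epoch_start(self, trainer, pl_module):
        super().on_train_epoch_start(trainer, pl_module)
        # Before #16695 this attribute was `self.main_progress_bar`.
        self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
```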


### Deprecated

src/lightning/pytorch/callbacks/progress/base.py (1 addition & 22 deletions)
@@ -170,27 +170,6 @@ def total_val_batches(self) -> Union[int, float]:
"""
return sum(self.trainer.num_val_batches) if self.trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0

-@property
-def total_batches_current_epoch(self) -> Union[int, float]:
-total_train_batches = self.total_train_batches
-total_val_batches = self.total_val_batches
-assert self._trainer is not None
-
-if total_train_batches != float("inf") and total_val_batches != float("inf"):
-# val can be checked multiple times per epoch
-val_check_batch = self.trainer.val_check_batch
-if self.trainer.check_val_every_n_epoch is None:
-train_batches_processed = self.trainer.fit_loop.total_batch_idx + 1
-val_checks_per_epoch = ((train_batches_processed + total_train_batches) // val_check_batch) - (
-train_batches_processed // val_check_batch
-)
-else:
-val_checks_per_epoch = total_train_batches // val_check_batch
-
-total_val_batches = total_val_batches * val_checks_per_epoch
-
-return total_train_batches + total_val_batches
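To make the removed bookkeeping concrete, a worked example of the old arithmetic under assumed numbers (100 train batches, 20 val batches, `val_check_interval=0.25`); the values are illustrative, not taken from the PR:

```python
# What the deleted `total_batches_current_epoch` reported for one epoch.
total_train_batches = 100
total_val_batches = 20
val_check_batch = 25  # val_check_interval=0.25 -> a check every 25 train batches

val_checks_per_epoch = total_train_batches // val_check_batch  # 4
old_total = total_train_batches + total_val_batches * val_checks_per_epoch  # 180
new_total = total_train_batches  # after this PR the train bar counts only these
print(old_total, new_total)  # 180 100
```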

def has_dataloader_changed(self, dataloader_idx: int) -> bool:
old_dataloader_idx = self._current_eval_dataloader_idx
self._current_eval_dataloader_idx = dataloader_idx
@@ -208,7 +187,7 @@ def enable(self) -> None:

The :class:`~lightning.pytorch.trainer.trainer.Trainer` will call this in e.g. pre-training
routines like the :ref:`learning rate finder <advanced/training_tricks:Learning Rate Finder>`
-to temporarily enable and disable the main progress bar.
+to temporarily enable and disable the training progress bar.
"""
raise NotImplementedError

src/lightning/pytorch/callbacks/progress/rich_progress.py (12 additions & 15 deletions)
@@ -156,7 +156,7 @@ def render(self, task: "Task") -> Text:
if (
self._trainer.state.fn != "fit"
or self._trainer.sanity_checking
-or self._trainer.progress_bar_callback.main_progress_bar_id != task.id
+or self._trainer.progress_bar_callback.train_progress_bar_id != task.id
):
return Text()
if self._trainer.training and task.id not in self._tasks:
@@ -256,7 +256,7 @@ def __init__(
self._console_kwargs = console_kwargs or {}
self._enabled: bool = True
self.progress: Optional[CustomProgress] = None
-self.main_progress_bar_id: Optional["TaskID"]
+self.train_progress_bar_id: Optional["TaskID"]
self.val_sanity_progress_bar_id: Optional["TaskID"] = None
self.val_progress_bar_id: Optional["TaskID"]
self.test_progress_bar_id: Optional["TaskID"]
@@ -280,10 +280,10 @@ def is_disabled(self) -> bool:
return not self.is_enabled

@property
-def main_progress_bar(self) -> Task:
+def train_progress_bar(self) -> Task:
assert self.progress is not None
-assert self.main_progress_bar_id is not None
-return self.progress.tasks[self.main_progress_bar_id]
+assert self.train_progress_bar_id is not None
+return self.progress.tasks[self.train_progress_bar_id]

@property
def val_sanity_check_bar(self) -> Task:
@@ -362,18 +362,18 @@ def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod
def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self.is_disabled:
return
-total_batches = self.total_batches_current_epoch
+total_batches = self.total_train_batches
train_description = self._get_train_description(trainer.current_epoch)

-if self.main_progress_bar_id is not None and self._leave:
+if self.train_progress_bar_id is not None and self._leave:
self._stop_progress()
self._init_progress(trainer)
if self.progress is not None:
-if self.main_progress_bar_id is None:
-self.main_progress_bar_id = self._add_task(total_batches, train_description)
+if self.train_progress_bar_id is None:
+self.train_progress_bar_id = self._add_task(total_batches, train_description)
else:
self.progress.reset(
-self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
+self.train_progress_bar_id, total=total_batches, description=train_description, visible=True
)

self.refresh()
@@ -470,7 +470,7 @@ def on_predict_batch_start(
def on_train_batch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
) -> None:
-self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
+self._update(self.train_progress_bar_id, self.train_batch_idx)
self._update_metrics(trainer, pl_module)
self.refresh()

@@ -491,9 +491,6 @@ def on_validation_batch_end(
if trainer.sanity_checking:
self._update(self.val_sanity_progress_bar_id, self.val_batch_idx)
elif self.val_progress_bar_id is not None:
-# check to see if we should update the main training progress bar
-if self.main_progress_bar_id is not None:
-self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
self._update(self.val_progress_bar_id, self.val_batch_idx)
self.refresh()

@@ -544,7 +541,7 @@ def _stop_progress(self) -> None:
self._progress_stopped = True

def _reset_progress_bar_ids(self) -> None:
-self.main_progress_bar_id = None
+self.train_progress_bar_id = None
self.val_sanity_progress_bar_id = None
self.val_progress_bar_id = None
self.test_progress_bar_id = None
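The rename applies equally to the Rich-based bar. A minimal sketch of enabling it, assuming the public `lightning.pytorch` API; `leave=True` mirrors the `_leave` branch in `on_train_epoch_start` above:

```python
import lightning.pytorch as pl
from lightning.pytorch.callbacks import RichProgressBar

# leave=True keeps each epoch's finished train bar on screen instead of
# resetting the same task in place.
trainer = pl.Trainer(callbacks=[RichProgressBar(leave=True)])
```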
src/lightning/pytorch/callbacks/progress/tqdm_progress.py (28 additions & 33 deletions)
@@ -63,8 +63,8 @@ class TQDMProgressBar(ProgressBarBase):
:mod:`tqdm` package and shows up to four different bars:

- **sanity check progress:** the progress during the sanity check run
-- **main progress:** shows training + validation progress combined. It also accounts for
-multiple validation runs during training when
+- **train progress:** shows the training progress. It will pause if validation starts and will resume
+when it ends, and also accounts for multiple validation runs during training when
:paramref:`~lightning.pytorch.trainer.trainer.Trainer.val_check_interval` is used.
- **validation progress:** only visible during validation;
shows total progress over all validation datasets.
@@ -103,7 +103,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0):
self._refresh_rate = self._resolve_refresh_rate(refresh_rate)
self._process_position = process_position
self._enabled = True
-self._main_progress_bar: Optional[_tqdm] = None
+self._train_progress_bar: Optional[_tqdm] = None
self._val_progress_bar: Optional[_tqdm] = None
self._test_progress_bar: Optional[_tqdm] = None
self._predict_progress_bar: Optional[_tqdm] = None
@@ -113,14 +113,14 @@ def __getstate__(self) -> Dict:
return {k: v if not isinstance(v, _tqdm) else None for k, v in vars(self).items()}

@property
-def main_progress_bar(self) -> _tqdm:
-if self._main_progress_bar is None:
-raise TypeError(f"The `{self.__class__.__name__}._main_progress_bar` reference has not been set yet.")
-return self._main_progress_bar
+def train_progress_bar(self) -> _tqdm:
+if self._train_progress_bar is None:
+raise TypeError(f"The `{self.__class__.__name__}._train_progress_bar` reference has not been set yet.")
+return self._train_progress_bar

-@main_progress_bar.setter
-def main_progress_bar(self, bar: _tqdm) -> None:
-self._main_progress_bar = bar
+@train_progress_bar.setter
+def train_progress_bar(self, bar: _tqdm) -> None:
+self._train_progress_bar = bar

@property
def val_progress_bar(self) -> _tqdm:
@@ -216,7 +216,7 @@ def init_predict_tqdm(self) -> Tqdm:

def init_validation_tqdm(self) -> Tqdm:
"""Override this to customize the tqdm bar for validation."""
-# The main progress bar doesn't exist in `trainer.validate()`
+# The train progress bar doesn't exist in `trainer.validate()`
has_main_bar = self.trainer.state.fn != "validate"
bar = Tqdm(
desc=self.validation_description,
@@ -242,33 +242,32 @@ def init_test_tqdm(self) -> Tqdm:

def on_sanity_check_start(self, *_: Any) -> None:
self.val_progress_bar = self.init_sanity_tqdm()
-self.main_progress_bar = Tqdm(disable=True) # dummy progress bar
+self.train_progress_bar = Tqdm(disable=True) # dummy progress bar

def on_sanity_check_end(self, *_: Any) -> None:
-self.main_progress_bar.close()
+self.train_progress_bar.close()
self.val_progress_bar.close()

def on_train_start(self, *_: Any) -> None:
-self.main_progress_bar = self.init_train_tqdm()
+self.train_progress_bar = self.init_train_tqdm()

def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
-total_batches = self.total_batches_current_epoch
-self.main_progress_bar.reset(convert_inf(total_batches))
-self.main_progress_bar.initial = 0
-self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
+self.train_progress_bar.reset(convert_inf(self.total_train_batches))
+self.train_progress_bar.initial = 0
+self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
-current = self.train_batch_idx + self._val_processed
-if self._should_update(current, self.main_progress_bar.total):
-_update_n(self.main_progress_bar, current)
-self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
+current = self.train_batch_idx
+if self._should_update(current, self.train_progress_bar.total):
+_update_n(self.train_progress_bar, current)
+self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-if not self.main_progress_bar.disable:
-self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
+if not self.train_progress_bar.disable:
+self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))

def on_train_end(self, *_: Any) -> None:
-self.main_progress_bar.close()
+self.train_progress_bar.close()

def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if not trainer.sanity_checking:
@@ -289,13 +288,9 @@ def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
_update_n(self.val_progress_bar, self.val_batch_idx)

-current = self.train_batch_idx + self._val_processed
-if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
-_update_n(self.main_progress_bar, current)
-
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-if self._main_progress_bar is not None and trainer.state.fn == "fit":
-self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
+if self._train_progress_bar is not None and trainer.state.fn == "fit":
+self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
self.val_progress_bar.close()
self.reset_dataloader_idx_tracker()

@@ -344,8 +339,8 @@ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule")
def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None:
active_progress_bar = None

-if self._main_progress_bar is not None and not self.main_progress_bar.disable:
-active_progress_bar = self.main_progress_bar
+if self._train_progress_bar is not None and not self.train_progress_bar.disable:
+active_progress_bar = self.train_progress_bar
elif self._val_progress_bar is not None and not self.val_progress_bar.disable:
active_progress_bar = self.val_progress_bar
elif self._test_progress_bar is not None and not self.test_progress_bar.disable:
tests/tests_pytorch/callbacks/progress/test_base_progress.py (0 additions & 34 deletions)

This file was deleted.
