2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -58,6 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - The strategy selected by `accelerator="hpu"` now defaults to `find_unused_parameters=False` ([#16611](https://github.com/Lightning-AI/lightning/pull/16611))

+- The main progress bar displayed during training no longer includes the combined progress for validation ([#16695](https://github.com/Lightning-AI/lightning/pull/16695))
+

 ### Deprecated

21 changes: 0 additions & 21 deletions src/lightning/pytorch/callbacks/progress/base.py
@@ -170,27 +170,6 @@ def total_val_batches(self) -> Union[int, float]:
         """
         return sum(self.trainer.num_val_batches) if self.trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0

-    @property
-    def total_batches_current_epoch(self) -> Union[int, float]:
-        total_train_batches = self.total_train_batches
-        total_val_batches = self.total_val_batches
-        assert self._trainer is not None
-
-        if total_train_batches != float("inf") and total_val_batches != float("inf"):
-            # val can be checked multiple times per epoch
-            val_check_batch = self.trainer.val_check_batch
-            if self.trainer.check_val_every_n_epoch is None:
-                train_batches_processed = self.trainer.fit_loop.total_batch_idx + 1
-                val_checks_per_epoch = ((train_batches_processed + total_train_batches) // val_check_batch) - (
-                    train_batches_processed // val_check_batch
-                )
-            else:
-                val_checks_per_epoch = total_train_batches // val_check_batch
-
-            total_val_batches = total_val_batches * val_checks_per_epoch
-
-        return total_train_batches + total_val_batches
-
     def has_dataloader_changed(self, dataloader_idx: int) -> bool:
         old_dataloader_idx = self._current_eval_dataloader_idx
         self._current_eval_dataloader_idx = dataloader_idx
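For context, a minimal standalone sketch of the arithmetic the deleted `total_batches_current_epoch` property performed in the common `check_val_every_n_epoch` case. The batch counts are made up for illustration, not taken from the diff; the sketch shows why the old main-bar total combined train and validation work, which is the behavior this PR removes:

```python
total_train_batches = 64  # illustrative
total_val_batches = 32    # illustrative
val_check_batch = 16      # validate every 16 training batches (illustrative)

# Validation can run several times per epoch, so the old property scaled
# the per-check validation batch count by the number of checks.
val_checks_per_epoch = total_train_batches // val_check_batch                # 4
old_total = total_train_batches + total_val_batches * val_checks_per_epoch   # 64 + 128 = 192
new_total = total_train_batches                                              # 64 after this PR
print(old_total, new_total)
```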
7 changes: 2 additions & 5 deletions src/lightning/pytorch/callbacks/progress/rich_progress.py
@@ -362,7 +362,7 @@ def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod
     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if self.is_disabled:
             return
-        total_batches = self.total_batches_current_epoch
+        total_batches = self.total_train_batches
         train_description = self._get_train_description(trainer.current_epoch)

         if self.main_progress_bar_id is not None and self._leave:
@@ -470,7 +470,7 @@ def on_predict_batch_start(
     def on_train_batch_end(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
     ) -> None:
-        self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
+        self._update(self.main_progress_bar_id, self.train_batch_idx)
         self._update_metrics(trainer, pl_module)
         self.refresh()

@@ -491,9 +491,6 @@ def on_validation_batch_end(
         if trainer.sanity_checking:
             self._update(self.val_sanity_progress_bar_id, self.val_batch_idx)
         elif self.val_progress_bar_id is not None:
-            # check to see if we should update the main training progress bar
-            if self.main_progress_bar_id is not None:
-                self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed)
             self._update(self.val_progress_bar_id, self.val_batch_idx)
         self.refresh()

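The `rich` changes above share one pattern: the main task's completed count is set to the absolute train batch index, with no validation contribution. A minimal sketch of that pattern against the `rich` API directly; the task name and batch count are illustrative, not from the diff:

```python
from rich.progress import Progress

with Progress() as progress:
    total_train_batches = 8  # illustrative
    main_task = progress.add_task("Epoch 0", total=total_train_batches)
    for train_batch_idx in range(1, total_train_batches + 1):
        # Mirrors self._update(self.main_progress_bar_id, self.train_batch_idx):
        # set an absolute completed count instead of advancing by a delta.
        progress.update(main_task, completed=train_batch_idx)
```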
9 changes: 2 additions & 7 deletions src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -252,13 +252,12 @@ def on_train_start(self, *_: Any) -> None:
         self.main_progress_bar = self.init_train_tqdm()

     def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
-        total_batches = self.total_batches_current_epoch
-        self.main_progress_bar.reset(convert_inf(total_batches))
+        self.main_progress_bar.reset(convert_inf(self.total_train_batches))
         self.main_progress_bar.initial = 0
         self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

     def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None:
-        current = self.train_batch_idx + self._val_processed
+        current = self.train_batch_idx
         if self._should_update(current, self.main_progress_bar.total):
             _update_n(self.main_progress_bar, current)
             self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
@@ -289,10 +288,6 @@ def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None:
         if self._should_update(self.val_batch_idx, self.val_progress_bar.total):
             _update_n(self.val_progress_bar, self.val_batch_idx)

-        current = self.train_batch_idx + self._val_processed
-        if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total):
-            _update_n(self.main_progress_bar, current)
-
     def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         if self._main_progress_bar is not None and trainer.state.fn == "fit":
             self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module))
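The tqdm updates remain throttled by the bar's refresh rate. Assuming `_should_update` keeps its usual rule of firing every `refresh_rate` batches and on the final batch (an assumption about code not shown in this diff), the main bar's update positions reduce to:

```python
# Assumed gating rule -- a sketch, not copied from the diff.
def should_update(current: int, total: int, refresh_rate: int = 3) -> bool:
    return current % refresh_rate == 0 or current == total

total_train_batches = 5  # illustrative
print([i for i in range(1, total_train_batches + 1) if should_update(i, total_train_batches)])
# [3, 5] -- with the initial reset at 0, this gives the [0, 3, 5] pattern
# expected by the tqdm tests further down.
```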
34 changes: 0 additions & 34 deletions tests/tests_pytorch/callbacks/progress/test_base_progress.py

This file was deleted.

27 changes: 14 additions & 13 deletions tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
@@ -77,8 +77,8 @@ def predict_dataloader(self):

     with mock.patch("lightning.pytorch.callbacks.progress.rich_progress.Progress.update") as mocked:
         trainer.fit(model)
-        # 3 for main progress bar and 1 for val progress bar
-        assert mocked.call_count == 4
+        # 2 for main progress bar and 1 for val progress bar
+        assert mocked.call_count == 3

     with mock.patch("lightning.pytorch.callbacks.progress.rich_progress.Progress.update") as mocked:
         trainer.validate(model)
@@ -214,16 +214,17 @@ def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir):
 @pytest.mark.parametrize(
     "refresh_rate,train_batches,val_batches,expected_call_count",
     [
-        (3, 6, 6, 4 + 3),
-        (4, 6, 6, 3 + 3),
-        (7, 6, 6, 2 + 2),
-        (1, 2, 3, 5 + 4),
+        # note: there is always one extra update at the very end (+1)
+        (3, 6, 6, 2 + 2 + 1),
+        (4, 6, 6, 2 + 2 + 1),
+        (7, 6, 6, 1 + 1 + 1),
+        (1, 2, 3, 2 + 3 + 1),
         (1, 0, 0, 0 + 0),
         (3, 1, 0, 1 + 0),
-        (3, 1, 1, 1 + 2),
+        (3, 1, 1, 1 + 1 + 1),
         (3, 5, 0, 2 + 0),
-        (3, 5, 2, 3 + 2),
-        (6, 5, 2, 2 + 2),
+        (3, 5, 2, 2 + 1 + 1),
+        (6, 5, 2, 1 + 1 + 1),
     ],
 )
 def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, train_batches, val_batches, expected_call_count):
@@ -246,8 +247,8 @@ def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, train_batches

     if train_batches > 0:
         fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
-        assert fit_main_bar.completed == train_batches + val_batches
-        assert fit_main_bar.total == train_batches + val_batches
+        assert fit_main_bar.completed == train_batches
+        assert fit_main_bar.total == train_batches
         assert fit_main_bar.visible
     if val_batches > 0:
         fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
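The new `expected_call_count` values in the parametrization above can be reproduced by a hypothetical helper (not part of the test file), assuming each bar updates every `refresh_rate` batches plus once on its final batch, and that the extra end-of-epoch update only occurs when validation ran:

```python
def bar_updates(n: int, refresh_rate: int) -> int:
    # Updates fire at multiples of refresh_rate and on the final batch.
    return sum(1 for i in range(1, n + 1) if i % refresh_rate == 0 or i == n)

def expected_calls(refresh_rate: int, train_batches: int, val_batches: int) -> int:
    extra = 1 if val_batches > 0 else 0  # the "+1" noted in the parametrization
    return bar_updates(train_batches, refresh_rate) + bar_updates(val_batches, refresh_rate) + extra

assert expected_calls(3, 6, 6) == 2 + 2 + 1
assert expected_calls(7, 6, 6) == 1 + 1 + 1
assert expected_calls(3, 5, 2) == 2 + 1 + 1
assert expected_calls(3, 1, 0) == 1 + 0
```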
@@ -294,8 +295,8 @@ def test_rich_progress_bar_counter_with_val_check_interval(tmpdir):
     trainer.fit(model)

     fit_main_progress_bar = progress_bar.progress.tasks[1]
-    assert fit_main_progress_bar.completed == 7 + 3 * 4
-    assert fit_main_progress_bar.total == 7 + 3 * 4
+    assert fit_main_progress_bar.completed == 7
+    assert fit_main_progress_bar.total == 7

     fit_val_bar = progress_bar.progress.tasks[2]
     assert fit_val_bar.completed == 4
53 changes: 22 additions & 31 deletions tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
@@ -162,9 +162,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
     n = trainer.num_training_batches
     m = trainer.num_val_batches
     assert len(trainer.train_dataloader) == n
-    # main progress bar should have reached the end (train batches + val batches)
-    assert pbar.main_progress_bar.total == n + sum(m)
-    assert pbar.main_progress_bar.n == n + sum(m)
+    # main progress bar should have reached the end
+    assert pbar.main_progress_bar.total == n
+    assert pbar.main_progress_bar.n == n
     assert pbar.main_progress_bar.leave

     # check val progress bar total
@@ -214,9 +214,9 @@ def test_tqdm_progress_bar_fast_dev_run(tmpdir):
     assert 1 == pbar.val_progress_bar.n
     assert 1 == pbar.val_progress_bar.total

-    # the main progress bar should display 2 batches (1 train, 1 val)
-    assert 2 == pbar.main_progress_bar.total
-    assert 2 == pbar.main_progress_bar.n
+    # the main progress bar should display 1 batch
+    assert 1 == pbar.main_progress_bar.total
+    assert 1 == pbar.main_progress_bar.n

     trainer.validate(model)

@@ -266,24 +266,18 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal
     assert trainer.progress_bar_callback.refresh_rate == refresh_rate

     trainer.fit(model)
-    assert (
-        pbar.train_batches_seen + pbar.val_batches_seen
-        == 3 * pbar.main_progress_bar.total + trainer.num_sanity_val_steps
-    )
+    assert pbar.train_batches_seen == 3 * pbar.main_progress_bar.total
+    assert pbar.val_batches_seen == 3 * pbar.val_progress_bar.total + trainer.num_sanity_val_steps
     assert pbar.test_batches_seen == 0

     trainer.validate(model)
-    assert (
-        pbar.train_batches_seen + pbar.val_batches_seen
-        == 3 * pbar.main_progress_bar.total + pbar.val_progress_bar.total + trainer.num_sanity_val_steps
-    )
+    assert pbar.train_batches_seen == 3 * pbar.main_progress_bar.total
+    assert pbar.val_batches_seen == 4 * pbar.val_progress_bar.total + trainer.num_sanity_val_steps
     assert pbar.test_batches_seen == 0

     trainer.test(model)
-    assert (
-        pbar.train_batches_seen + pbar.val_batches_seen
-        == 3 * pbar.main_progress_bar.total + pbar.val_progress_bar.total + trainer.num_sanity_val_steps
-    )
+    assert pbar.train_batches_seen == 3 * pbar.main_progress_bar.total
+    assert pbar.val_batches_seen == 4 * pbar.val_progress_bar.total + trainer.num_sanity_val_steps
     assert pbar.test_batches_seen == pbar.test_progress_bar.total
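The rewritten multipliers above follow from a simple tally, sketched here with hypothetical numbers (the loader length and sanity-step count are illustrative, not from the test): `fit()` runs three epochs, each validating once after an initial sanity check, and the later `validate()` call adds a fourth full pass over the val dataloader:

```python
num_epochs = 3            # implied by the 3 * main_progress_bar.total assertions
val_loader_len = 4        # illustrative
num_sanity_val_steps = 2  # illustrative

val_batches_after_fit = num_epochs * val_loader_len + num_sanity_val_steps             # 3 * total + sanity
val_batches_after_validate = (num_epochs + 1) * val_loader_len + num_sanity_val_steps  # 4 * total + sanity
```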


@@ -345,13 +339,13 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
 @pytest.mark.parametrize(
     "train_batches,val_batches,refresh_rate,train_updates,val_updates",
     [
-        [2, 3, 1, [0, 1, 2, 3, 4, 5], [0, 1, 2, 3]],
+        [2, 3, 1, [0, 1, 2], [0, 1, 2, 3]],
         [0, 0, 3, None, None],
         [1, 0, 3, [0, 1], None],
-        [1, 1, 3, [0, 2], [0, 1]],
+        [1, 1, 3, [0, 1], [0, 1]],
         [5, 0, 3, [0, 3, 5], None],
-        [5, 2, 3, [0, 3, 6, 7], [0, 2]],
-        [5, 2, 6, [0, 6, 7], [0, 2]],
+        [5, 2, 3, [0, 3, 5], [0, 2]],
+        [5, 2, 6, [0, 5], [0, 2]],
     ],
 )
 def test_main_progress_bar_update_amount(
@@ -562,7 +556,7 @@ def test_tqdm_progress_bar_can_be_pickled():

 @pytest.mark.parametrize(
     ["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"],
-    [(4, [0, 3, 6, 9, 12, 14], [0, 3, 6, 7]), (0.5, [0, 3, 6, 9, 12, 15, 18, 21], [0, 3, 6, 7])],
+    [(4, [0, 3, 6, 7], [0, 3, 6, 7]), (0.5, [0, 3, 6, 7], [0, 3, 6, 7])],
 )
 def test_progress_bar_max_val_check_interval(
     tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates
@@ -590,14 +584,13 @@ def test_progress_bar_max_val_check_interval(
         max(1, int(limit_batches * val_check_interval)) if isinstance(val_check_interval, float) else val_check_interval
     )
     assert trainer.val_check_batch == val_check_batch
-    val_checks_per_epoch = math.ceil(limit_batches // val_check_batch)
+    math.ceil(limit_batches // val_check_batch)
     pbar_callback = trainer.progress_bar_callback
-    total_val_batches = limit_batches * val_checks_per_epoch

     assert pbar_callback.val_progress_bar.n == limit_batches
     assert pbar_callback.val_progress_bar.total == limit_batches
-    assert pbar_callback.main_progress_bar.n == limit_batches + total_val_batches
-    assert pbar_callback.main_progress_bar.total == limit_batches + total_val_batches
+    assert pbar_callback.main_progress_bar.n == limit_batches
+    assert pbar_callback.main_progress_bar.total == limit_batches
     assert pbar_callback.is_enabled
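Under the same assumed refresh gating, the `[0, 3, 6, 7]` update lists in the parametrization above are consistent with a 7-batch limit and a refresh rate of 3; both values are inferred from the expected lists, since the test body is not fully shown in this excerpt:

```python
limit_batches, refresh_rate = 7, 3  # inferred, not read from the test

for val_check_interval in (4, 0.5):
    # Same expression as the test's val_check_batch computation above.
    val_check_batch = (
        max(1, int(limit_batches * val_check_interval))
        if isinstance(val_check_interval, float)
        else val_check_interval
    )
    print(val_check_interval, val_check_batch)  # 4 -> 4, 0.5 -> 3

# Main-bar positions: the initial reset, every refresh_rate-th batch, and the end.
print([0] + [i for i in range(1, limit_batches + 1) if i % refresh_rate == 0 or i == limit_batches])
# [0, 3, 6, 7], matching both parametrizations above
```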


@@ -629,16 +622,14 @@ def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
     total_train_batches = total_train_samples // (train_batch_size * world_size)
     val_check_batch = max(1, int(total_train_batches * val_check_interval))
     assert trainer.val_check_batch == val_check_batch
-    val_checks_per_epoch = total_train_batches / val_check_batch
     total_val_batches = total_val_samples // (val_batch_size * world_size)
     pbar_callback = trainer.progress_bar_callback

     if trainer.is_global_zero:
         assert pbar_callback.val_progress_bar.n == total_val_batches
         assert pbar_callback.val_progress_bar.total == total_val_batches
-        total_val_batches = total_val_batches * val_checks_per_epoch
-        assert pbar_callback.main_progress_bar.n == (total_train_batches + total_val_batches) // world_size
-        assert pbar_callback.main_progress_bar.total == (total_train_batches + total_val_batches) // world_size
+        assert pbar_callback.main_progress_bar.n == total_train_batches // world_size
+        assert pbar_callback.main_progress_bar.total == total_train_batches // world_size
         assert pbar_callback.is_enabled

