Skip to content

Commit 1607a33

Browse files
SkafteNicki authored and Borda committed
Fix TQDM progress bar showing the wrong total when using a finite and iterable dataloader (#21147)
* fix implementation reset
* add testing
* changelog

(cherry picked from commit 14a57c7)
1 parent 8bd363c commit 1607a33

File tree

3 files changed

+62
-4
lines changed

3 files changed

+62
-4
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2121
- Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
2222

2323

24+
- Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147))
25+
26+
2427
## [2.5.4] - 2025-08-29
2528

2629
### Fixed

src/lightning/pytorch/callbacks/progress/tqdm_progress.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,9 @@ def on_train_start(self, *_: Any) -> None:
265265
def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
266266
if self._leave:
267267
self.train_progress_bar = self.init_train_tqdm()
268-
self.train_progress_bar.reset(convert_inf(self.total_train_batches))
268+
total = convert_inf(self.total_train_batches)
269+
self.train_progress_bar.reset()
270+
self.train_progress_bar.total = total
269271
self.train_progress_bar.initial = 0
270272
self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
271273

@@ -306,7 +308,9 @@ def on_validation_batch_start(
306308
if not self.has_dataloader_changed(dataloader_idx):
307309
return
308310

309-
self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader))
311+
total = convert_inf(self.total_val_batches_current_dataloader)
312+
self.val_progress_bar.reset()
313+
self.val_progress_bar.total = total
310314
self.val_progress_bar.initial = 0
311315
desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
312316
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")
@@ -348,7 +352,9 @@ def on_test_batch_start(
348352
if not self.has_dataloader_changed(dataloader_idx):
349353
return
350354

351-
self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader))
355+
total = convert_inf(self.total_test_batches_current_dataloader)
356+
self.test_progress_bar.reset()
357+
self.test_progress_bar.total = total
352358
self.test_progress_bar.initial = 0
353359
self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")
354360

@@ -387,7 +393,9 @@ def on_predict_batch_start(
387393
if not self.has_dataloader_changed(dataloader_idx):
388394
return
389395

390-
self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader))
396+
total = convert_inf(self.total_predict_batches_current_dataloader)
397+
self.predict_progress_bar.reset()
398+
self.predict_progress_bar.total = total
391399
self.predict_progress_bar.initial = 0
392400
self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")
393401

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,3 +801,50 @@ def test_tqdm_leave(leave, tmp_path):
801801
)
802802
trainer.fit(model)
803803
assert pbar.init_train_tqdm.call_count == (4 if leave else 1)
804+
805+
806+
@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
807+
def test_tqdm_progress_bar_reset_behavior(tmp_path):
808+
"""Test that progress bars call reset() without parameters and set total separately."""
809+
model = BoringModel()
810+
811+
class ResetTrackingTqdm(MockTqdm):
812+
def __init__(self, *args, **kwargs):
813+
super().__init__(*args, **kwargs)
814+
self.reset_calls_with_params = []
815+
816+
def reset(self, total=None):
817+
self.reset_calls_with_params.append(total)
818+
super().reset(total)
819+
820+
trainer = Trainer(
821+
default_root_dir=tmp_path,
822+
limit_train_batches=2,
823+
limit_val_batches=2,
824+
max_epochs=1,
825+
logger=False,
826+
enable_checkpointing=False,
827+
)
828+
829+
pbar = trainer.progress_bar_callback
830+
831+
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", ResetTrackingTqdm):
832+
trainer.fit(model)
833+
834+
train_bar = pbar.train_progress_bar
835+
assert None in train_bar.reset_calls_with_params, (
836+
f"train reset() should be called without parameters, got calls: {train_bar.reset_calls_with_params}"
837+
)
838+
# Verify that total was set separately to the expected value
839+
assert 2 in train_bar.total_values, (
840+
f"train total should be set to 2 after reset(), got total_values: {train_bar.total_values}"
841+
)
842+
# Verify that validation progress bar reset() was called without parameters
843+
val_bar = pbar.val_progress_bar
844+
assert None in val_bar.reset_calls_with_params, (
845+
f"validation reset() should be called without parameters, got calls: {val_bar.reset_calls_with_params}"
846+
)
847+
# Verify that total was set separately to the expected value
848+
assert 2 in val_bar.total_values, (
849+
f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}"
850+
)

0 commit comments

Comments (0)