Skip to content

Commit 14a57c7

Browse files
authored
Fix TQDM progress bar showing the wrong total when using a finite and iterable dataloader (#21147)
* Fix the progress-bar reset implementation
* Add a regression test
* Update the changelog
1 parent 630db82 commit 14a57c7

File tree

3 files changed

+61
-4
lines changed

3 files changed

+61
-4
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3434
- Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
3535

3636

37+
- Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147))
38+
3739
---
3840

3941

src/lightning/pytorch/callbacks/progress/tqdm_progress.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,9 @@ def on_train_start(self, *_: Any) -> None:
265265
def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
266266
if self._leave:
267267
self.train_progress_bar = self.init_train_tqdm()
268-
self.train_progress_bar.reset(convert_inf(self.total_train_batches))
268+
total = convert_inf(self.total_train_batches)
269+
self.train_progress_bar.reset()
270+
self.train_progress_bar.total = total
269271
self.train_progress_bar.initial = 0
270272
self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
271273

@@ -306,7 +308,9 @@ def on_validation_batch_start(
306308
if not self.has_dataloader_changed(dataloader_idx):
307309
return
308310

309-
self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader))
311+
total = convert_inf(self.total_val_batches_current_dataloader)
312+
self.val_progress_bar.reset()
313+
self.val_progress_bar.total = total
310314
self.val_progress_bar.initial = 0
311315
desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
312316
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")
@@ -348,7 +352,9 @@ def on_test_batch_start(
348352
if not self.has_dataloader_changed(dataloader_idx):
349353
return
350354

351-
self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader))
355+
total = convert_inf(self.total_test_batches_current_dataloader)
356+
self.test_progress_bar.reset()
357+
self.test_progress_bar.total = total
352358
self.test_progress_bar.initial = 0
353359
self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")
354360

@@ -387,7 +393,9 @@ def on_predict_batch_start(
387393
if not self.has_dataloader_changed(dataloader_idx):
388394
return
389395

390-
self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader))
396+
total = convert_inf(self.total_predict_batches_current_dataloader)
397+
self.predict_progress_bar.reset()
398+
self.predict_progress_bar.total = total
391399
self.predict_progress_bar.initial = 0
392400
self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")
393401

tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,3 +812,50 @@ def test_tqdm_leave(leave, tmp_path):
812812
)
813813
trainer.fit(model)
814814
assert pbar.init_train_tqdm.call_count == (4 if leave else 1)
815+
816+
817+
@patch("lightning.pytorch.trainer.connectors.callback_connector._RICH_AVAILABLE", False)
818+
def test_tqdm_progress_bar_reset_behavior(tmp_path):
819+
"""Test that progress bars call reset() without parameters and set total separately."""
820+
model = BoringModel()
821+
822+
class ResetTrackingTqdm(MockTqdm):
823+
def __init__(self, *args, **kwargs):
824+
super().__init__(*args, **kwargs)
825+
self.reset_calls_with_params = []
826+
827+
def reset(self, total=None):
828+
self.reset_calls_with_params.append(total)
829+
super().reset(total)
830+
831+
trainer = Trainer(
832+
default_root_dir=tmp_path,
833+
limit_train_batches=2,
834+
limit_val_batches=2,
835+
max_epochs=1,
836+
logger=False,
837+
enable_checkpointing=False,
838+
)
839+
840+
pbar = trainer.progress_bar_callback
841+
842+
with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", ResetTrackingTqdm):
843+
trainer.fit(model)
844+
845+
train_bar = pbar.train_progress_bar
846+
assert None in train_bar.reset_calls_with_params, (
847+
f"train reset() should be called without parameters, got calls: {train_bar.reset_calls_with_params}"
848+
)
849+
# Verify that total was set separately to the expected value
850+
assert 2 in train_bar.total_values, (
851+
f"train total should be set to 2 after reset(), got total_values: {train_bar.total_values}"
852+
)
853+
# Verify that validation progress bar reset() was called without parameters
854+
val_bar = pbar.val_progress_bar
855+
assert None in val_bar.reset_calls_with_params, (
856+
f"validation reset() should be called without parameters, got calls: {val_bar.reset_calls_with_params}"
857+
)
858+
# Verify that total was set separately to the expected value
859+
assert 2 in val_bar.total_values, (
860+
f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}"
861+
)

0 commit comments

Comments (0)