diff --git a/docs/source-pytorch/common/progress_bar.rst b/docs/source-pytorch/common/progress_bar.rst
index d00c716bb83cf..e03c6491edc3e 100644
--- a/docs/source-pytorch/common/progress_bar.rst
+++ b/docs/source-pytorch/common/progress_bar.rst
@@ -22,7 +22,7 @@ The :class:`~pytorch_lightning.callbacks.TQDMProgressBar` uses the `tqdm
diff --git a/src/lightning/pytorch/callbacks/progress/progress_bar.py b/src/lightning/pytorch/callbacks/progress/progress_bar.py
--- a/src/lightning/pytorch/callbacks/progress/progress_bar.py
+++ b/src/lightning/pytorch/callbacks/progress/progress_bar.py
@@ ... @@ def total_val_batches(self) -> Union[int, float]:
         """
         return sum(self.trainer.num_val_batches) if self.trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0
 
-    @property
-    def total_batches_current_epoch(self) -> Union[int, float]:
-        total_train_batches = self.total_train_batches
-        total_val_batches = self.total_val_batches
-        assert self._trainer is not None
-
-        if total_train_batches != float("inf") and total_val_batches != float("inf"):
-            # val can be checked multiple times per epoch
-            val_check_batch = self.trainer.val_check_batch
-            if self.trainer.check_val_every_n_epoch is None:
-                train_batches_processed = self.trainer.fit_loop.total_batch_idx + 1
-                val_checks_per_epoch = ((train_batches_processed + total_train_batches) // val_check_batch) - (
-                    train_batches_processed // val_check_batch
-                )
-            else:
-                val_checks_per_epoch = total_train_batches // val_check_batch
-
-            total_val_batches = total_val_batches * val_checks_per_epoch
-
-        return total_train_batches + total_val_batches
-
     def has_dataloader_changed(self, dataloader_idx: int) -> bool:
         old_dataloader_idx = self._current_eval_dataloader_idx
         self._current_eval_dataloader_idx = dataloader_idx
@@ -208,7 +187,7 @@ def enable(self) -> None:
 
         The :class:`~lightning.pytorch.trainer.trainer.Trainer` will call this in e.g. pre-training
         routines like the :ref:`learning rate finder`
-        to temporarily enable and disable the main progress bar.
+        to temporarily enable and disable the training progress bar.
""" raise NotImplementedError diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index de5b83cb545c5..0e7871ab16f9f 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -156,7 +156,7 @@ def render(self, task: "Task") -> Text: if ( self._trainer.state.fn != "fit" or self._trainer.sanity_checking - or self._trainer.progress_bar_callback.main_progress_bar_id != task.id + or self._trainer.progress_bar_callback.train_progress_bar_id != task.id ): return Text() if self._trainer.training and task.id not in self._tasks: @@ -256,7 +256,7 @@ def __init__( self._console_kwargs = console_kwargs or {} self._enabled: bool = True self.progress: Optional[CustomProgress] = None - self.main_progress_bar_id: Optional["TaskID"] + self.train_progress_bar_id: Optional["TaskID"] self.val_sanity_progress_bar_id: Optional["TaskID"] = None self.val_progress_bar_id: Optional["TaskID"] self.test_progress_bar_id: Optional["TaskID"] @@ -280,10 +280,10 @@ def is_disabled(self) -> bool: return not self.is_enabled @property - def main_progress_bar(self) -> Task: + def train_progress_bar(self) -> Task: assert self.progress is not None - assert self.main_progress_bar_id is not None - return self.progress.tasks[self.main_progress_bar_id] + assert self.train_progress_bar_id is not None + return self.progress.tasks[self.train_progress_bar_id] @property def val_sanity_check_bar(self) -> Task: @@ -362,18 +362,18 @@ def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self.is_disabled: return - total_batches = self.total_batches_current_epoch + total_batches = self.total_train_batches train_description = self._get_train_description(trainer.current_epoch) - if self.main_progress_bar_id is not None and self._leave: + if self.train_progress_bar_id is not None and self._leave: self._stop_progress() self._init_progress(trainer) if self.progress is not None: - if self.main_progress_bar_id is None: - self.main_progress_bar_id = self._add_task(total_batches, train_description) + if self.train_progress_bar_id is None: + self.train_progress_bar_id = self._add_task(total_batches, train_description) else: self.progress.reset( - self.main_progress_bar_id, total=total_batches, description=train_description, visible=True + self.train_progress_bar_id, total=total_batches, description=train_description, visible=True ) self.refresh() @@ -470,7 +470,7 @@ def on_predict_batch_start( def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: - self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed) + self._update(self.train_progress_bar_id, self.train_batch_idx) self._update_metrics(trainer, pl_module) self.refresh() @@ -491,9 +491,6 @@ def on_validation_batch_end( if trainer.sanity_checking: self._update(self.val_sanity_progress_bar_id, self.val_batch_idx) elif self.val_progress_bar_id is not None: - # check to see if we should update the main training progress bar - if self.main_progress_bar_id is not None: - self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed) self._update(self.val_progress_bar_id, self.val_batch_idx) self.refresh() @@ -544,7 +541,7 @@ def _stop_progress(self) -> None: self._progress_stopped = True 
     def _reset_progress_bar_ids(self) -> None:
-        self.main_progress_bar_id = None
+        self.train_progress_bar_id = None
         self.val_sanity_progress_bar_id = None
         self.val_progress_bar_id = None
         self.test_progress_bar_id = None
diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
index 5af994c296c2b..bea7251ad168a 100644
--- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -63,8 +63,8 @@ class TQDMProgressBar(ProgressBarBase):
     :mod:`tqdm` package and shows up to four different bars:
 
     - **sanity check progress:** the progress during the sanity check run
-    - **main progress:** shows training + validation progress combined. It also accounts for
-      multiple validation runs during training when
+    - **train progress:** shows the training progress. It will pause if validation starts and will resume
+      when it ends, and also accounts for multiple validation runs during training when
       :paramref:`~lightning.pytorch.trainer.trainer.Trainer.val_check_interval` is used.
     - **validation progress:** only visible during validation; shows total progress over all validation datasets.
@@ -103,7 +103,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0):
         self._refresh_rate = self._resolve_refresh_rate(refresh_rate)
         self._process_position = process_position
         self._enabled = True
-        self._main_progress_bar: Optional[_tqdm] = None
+        self._train_progress_bar: Optional[_tqdm] = None
         self._val_progress_bar: Optional[_tqdm] = None
         self._test_progress_bar: Optional[_tqdm] = None
         self._predict_progress_bar: Optional[_tqdm] = None
@@ -113,14 +113,14 @@ def __getstate__(self) -> Dict:
         return {k: v if not isinstance(v, _tqdm) else None for k, v in vars(self).items()}
 
     @property
-    def main_progress_bar(self) -> _tqdm:
-        if self._main_progress_bar is None:
-            raise TypeError(f"The `{self.__class__.__name__}._main_progress_bar` reference has not been set yet.")
-        return self._main_progress_bar
+    def train_progress_bar(self) -> _tqdm:
+        if self._train_progress_bar is None:
+            raise TypeError(f"The `{self.__class__.__name__}._train_progress_bar` reference has not been set yet.")
+        return self._train_progress_bar
 
-    @main_progress_bar.setter
-    def main_progress_bar(self, bar: _tqdm) -> None:
-        self._main_progress_bar = bar
+    @train_progress_bar.setter
+    def train_progress_bar(self, bar: _tqdm) -> None:
+        self._train_progress_bar = bar
 
     @property
     def val_progress_bar(self) -> _tqdm:
@@ -216,7 +216,7 @@ def init_predict_tqdm(self) -> Tqdm:
 
     def init_validation_tqdm(self) -> Tqdm:
         """Override this to customize the tqdm bar for validation."""
-        # The main progress bar doesn't exist in `trainer.validate()`
+        # The train progress bar doesn't exist in `trainer.validate()`
         has_main_bar = self.trainer.state.fn != "validate"
         bar = Tqdm(
             desc=self.validation_description,
@@ -242,33 +242,32 @@ def init_test_tqdm(self) -> Tqdm:
 
     def on_sanity_check_start(self, *_: Any) -> None:
         self.val_progress_bar = self.init_sanity_tqdm()
-        self.main_progress_bar = Tqdm(disable=True)  # dummy progress bar
+        self.train_progress_bar = Tqdm(disable=True)  # dummy progress bar
 
     def on_sanity_check_end(self, *_: Any) -> None:
-        self.main_progress_bar.close()
+        self.train_progress_bar.close()
         self.val_progress_bar.close()
 
     def on_train_start(self, *_: Any) -> None:
-        self.main_progress_bar = self.init_train_tqdm()
+        self.train_progress_bar = self.init_train_tqdm()
 
     def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
"pl.Trainer", *_: Any) -> None: - total_batches = self.total_batches_current_epoch - self.main_progress_bar.reset(convert_inf(total_batches)) - self.main_progress_bar.initial = 0 - self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}") + self.train_progress_bar.reset(convert_inf(self.total_train_batches)) + self.train_progress_bar.initial = 0 + self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None: - current = self.train_batch_idx + self._val_processed - if self._should_update(current, self.main_progress_bar.total): - _update_n(self.main_progress_bar, current) - self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) + current = self.train_batch_idx + if self._should_update(current, self.train_progress_bar.total): + _update_n(self.train_progress_bar, current) + self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if not self.main_progress_bar.disable: - self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) + if not self.train_progress_bar.disable: + self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_train_end(self, *_: Any) -> None: - self.main_progress_bar.close() + self.train_progress_bar.close() def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if not trainer.sanity_checking: @@ -289,13 +288,9 @@ def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None: if self._should_update(self.val_batch_idx, self.val_progress_bar.total): _update_n(self.val_progress_bar, self.val_batch_idx) - current = self.train_batch_idx + self._val_processed - if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total): - _update_n(self.main_progress_bar, current) - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self._main_progress_bar is not None and trainer.state.fn == "fit": - self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) + if self._train_progress_bar is not None and trainer.state.fn == "fit": + self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) self.val_progress_bar.close() self.reset_dataloader_idx_tracker() @@ -344,8 +339,8 @@ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None: active_progress_bar = None - if self._main_progress_bar is not None and not self.main_progress_bar.disable: - active_progress_bar = self.main_progress_bar + if self._train_progress_bar is not None and not self.train_progress_bar.disable: + active_progress_bar = self.train_progress_bar elif self._val_progress_bar is not None and not self.val_progress_bar.disable: active_progress_bar = self.val_progress_bar elif self._test_progress_bar is not None and not self.test_progress_bar.disable: diff --git a/tests/tests_pytorch/callbacks/progress/test_base_progress.py b/tests/tests_pytorch/callbacks/progress/test_base_progress.py deleted file mode 100644 index 4ece74473b116..0000000000000 --- a/tests/tests_pytorch/callbacks/progress/test_base_progress.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright The Lightning AI team. 
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.trainer.trainer import Trainer
-
-
-def test_main_progress_bar_with_val_check_interval_int():
-    """Test the main progress bar count when val_check_interval=int and check_val_every_n_epoch=None."""
-    train_batches = 5
-    trainer = Trainer(
-        limit_train_batches=train_batches, limit_val_batches=10, val_check_interval=3, check_val_every_n_epoch=None
-    )
-    model = BoringModel()
-    trainer.progress_bar_callback.setup(trainer, model, stage="fit")
-    trainer.strategy.connect(model)
-    trainer._data_connector.attach_data(model)
-    trainer.reset_train_dataloader()
-    trainer.reset_val_dataloader()
-    expected = [15, 25, 25, 15]
-
-    for count in expected:
-        assert trainer.progress_bar_callback.total_batches_current_epoch == count
-        trainer.fit_loop.epoch_loop.batch_progress.total.ready += train_batches
diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
index 4f5a4ca92c0f3..6800d622e8fd4 100644
--- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
@@ -77,8 +77,8 @@ def predict_dataloader(self):
     with mock.patch("lightning.pytorch.callbacks.progress.rich_progress.Progress.update") as mocked:
         trainer.fit(model)
-    # 3 for main progress bar and 1 for val progress bar
-    assert mocked.call_count == 4
+    # 2 for train progress bar and 1 for val progress bar
+    assert mocked.call_count == 3
 
     with mock.patch("lightning.pytorch.callbacks.progress.rich_progress.Progress.update") as mocked:
         trainer.validate(model)
@@ -214,16 +214,17 @@ def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir):
 @pytest.mark.parametrize(
     "refresh_rate,train_batches,val_batches,expected_call_count",
     [
-        (3, 6, 6, 4 + 3),
-        (4, 6, 6, 3 + 3),
-        (7, 6, 6, 2 + 2),
-        (1, 2, 3, 5 + 4),
+        # note: there is always one extra update at the very end (+1)
+        (3, 6, 6, 2 + 2 + 1),
+        (4, 6, 6, 2 + 2 + 1),
+        (7, 6, 6, 1 + 1 + 1),
+        (1, 2, 3, 2 + 3 + 1),
         (1, 0, 0, 0 + 0),
         (3, 1, 0, 1 + 0),
-        (3, 1, 1, 1 + 2),
+        (3, 1, 1, 1 + 1 + 1),
         (3, 5, 0, 2 + 0),
-        (3, 5, 2, 3 + 2),
-        (6, 5, 2, 2 + 2),
+        (3, 5, 2, 2 + 1 + 1),
+        (6, 5, 2, 1 + 1 + 1),
     ],
 )
 def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, train_batches, val_batches, expected_call_count):
@@ -246,8 +247,8 @@ def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, train_batches
 
     if train_batches > 0:
         fit_main_bar = trainer.progress_bar_callback.progress.tasks[0]
-        assert fit_main_bar.completed == train_batches + val_batches
-        assert fit_main_bar.total == train_batches + val_batches
+        assert fit_main_bar.completed == train_batches
+        assert fit_main_bar.total == train_batches
         assert fit_main_bar.visible
     if val_batches > 0:
         fit_val_bar = trainer.progress_bar_callback.progress.tasks[1]
@@ -293,9 +294,9 @@ def test_rich_progress_bar_counter_with_val_check_interval(tmpdir):
     )
     trainer.fit(model)
 
-    fit_main_progress_bar = progress_bar.progress.tasks[1]
-    assert fit_main_progress_bar.completed == 7 + 3 * 4
-    assert fit_main_progress_bar.total == 7 + 3 * 4
+    fit_train_progress_bar = progress_bar.progress.tasks[1]
+    assert fit_train_progress_bar.completed == 7
+    assert fit_train_progress_bar.total == 7
 
     fit_val_bar = progress_bar.progress.tasks[2]
     assert fit_val_bar.completed == 4
@@ -334,12 +335,12 @@ def training_step(self, *args, **kwargs):
     trainer = Trainer(default_root_dir=tmpdir, callbacks=progress_bar, fast_dev_run=True, logger=CSVLogger(tmpdir))
     trainer.fit(model)
 
-    main_progress_bar_id = progress_bar.main_progress_bar_id
+    train_progress_bar_id = progress_bar.train_progress_bar_id
     val_progress_bar_id = progress_bar.val_progress_bar_id
     rendered = progress_bar.progress.columns[-1]._renderable_cache
 
     for key in ("loss", "v_num", "train_loss"):
-        assert key in rendered[main_progress_bar_id][1]
+        assert key in rendered[train_progress_bar_id][1]
         assert key not in rendered[val_progress_bar_id][1]
@@ -454,7 +455,7 @@ def test_rich_progress_bar_reset_bars():
     assert bar._progress_stopped is False
 
     def _set_fake_bar_ids():
-        bar.main_progress_bar_id = 0
+        bar.train_progress_bar_id = 0
         bar.val_sanity_progress_bar_id = 1
         bar.val_progress_bar_id = 2
         bar.test_progress_bar_id = 3
@@ -469,7 +470,7 @@ def _set_fake_bar_ids():
     bar.teardown(Mock(), Mock(), Mock())
 
     # assert all bars are reset
-    assert bar.main_progress_bar_id is None
+    assert bar.train_progress_bar_id is None
     assert bar.val_sanity_progress_bar_id is None
     assert bar.val_progress_bar_id is None
     assert bar.test_progress_bar_id is None
@@ -506,7 +507,7 @@ def test_rich_progress_bar_disabled(tmpdir):
         trainer.predict(model)
         mocked.assert_not_called()
 
-    assert bar.main_progress_bar_id is None
+    assert bar.train_progress_bar_id is None
     assert bar.val_sanity_progress_bar_id is None
     assert bar.val_progress_bar_id is None
     assert bar.test_progress_bar_id is None
diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
index 9605bae1b8441..2181f5a842aac 100644
--- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
@@ -162,10 +162,10 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
     n = trainer.num_training_batches
     m = trainer.num_val_batches
     assert len(trainer.train_dataloader) == n
-    # main progress bar should have reached the end (train batches + val batches)
-    assert pbar.main_progress_bar.total == n + sum(m)
-    assert pbar.main_progress_bar.n == n + sum(m)
-    assert pbar.main_progress_bar.leave
+    # train progress bar should have reached the end
+    assert pbar.train_progress_bar.total == n
+    assert pbar.train_progress_bar.n == n
+    assert pbar.train_progress_bar.leave
 
     # check val progress bar total
     assert pbar.val_progress_bar.total_values == m
@@ -214,9 +214,9 @@ def test_tqdm_progress_bar_fast_dev_run(tmpdir):
     assert 1 == pbar.val_progress_bar.n
     assert 1 == pbar.val_progress_bar.total
 
-    # the main progress bar should display 2 batches (1 train, 1 val)
-    assert 2 == pbar.main_progress_bar.total
-    assert 2 == pbar.main_progress_bar.n
+    # the train progress bar should display 1 batch
+    assert 1 == pbar.train_progress_bar.total
+    assert 1 == pbar.train_progress_bar.n
 
     trainer.validate(model)
@@ -266,24 +266,18 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal
     assert trainer.progress_bar_callback.refresh_rate == refresh_rate
 
     trainer.fit(model)
-    assert (
-        pbar.train_batches_seen + pbar.val_batches_seen
-        == 3 * pbar.main_progress_bar.total + trainer.num_sanity_val_steps
-    )
+    assert pbar.train_batches_seen == 3 * pbar.train_progress_bar.total
+    assert pbar.val_batches_seen == 3 * pbar.val_progress_bar.total + trainer.num_sanity_val_steps
     assert pbar.test_batches_seen == 0
 
     trainer.validate(model)
-    assert (
-        pbar.train_batches_seen + pbar.val_batches_seen
-        == 3 * pbar.main_progress_bar.total + pbar.val_progress_bar.total + trainer.num_sanity_val_steps
-    )
+    assert pbar.train_batches_seen == 3 * pbar.train_progress_bar.total
+    assert pbar.val_batches_seen == 4 * pbar.val_progress_bar.total + trainer.num_sanity_val_steps
    assert pbar.test_batches_seen == 0
 
     trainer.test(model)
-    assert (
-        pbar.train_batches_seen + pbar.val_batches_seen
-        == 3 * pbar.main_progress_bar.total + pbar.val_progress_bar.total + trainer.num_sanity_val_steps
-    )
+    assert pbar.train_batches_seen == 3 * pbar.train_progress_bar.total
+    assert pbar.val_batches_seen == 4 * pbar.val_progress_bar.total + trainer.num_sanity_val_steps
     assert pbar.test_batches_seen == pbar.test_progress_bar.total
@@ -345,19 +339,19 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir):
 
 @pytest.mark.parametrize(
     "train_batches,val_batches,refresh_rate,train_updates,val_updates",
     [
-        [2, 3, 1, [0, 1, 2, 3, 4, 5], [0, 1, 2, 3]],
+        [2, 3, 1, [0, 1, 2], [0, 1, 2, 3]],
         [0, 0, 3, None, None],
         [1, 0, 3, [0, 1], None],
-        [1, 1, 3, [0, 2], [0, 1]],
+        [1, 1, 3, [0, 1], [0, 1]],
         [5, 0, 3, [0, 3, 5], None],
-        [5, 2, 3, [0, 3, 6, 7], [0, 2]],
-        [5, 2, 6, [0, 6, 7], [0, 2]],
+        [5, 2, 3, [0, 3, 5], [0, 2]],
+        [5, 2, 6, [0, 5], [0, 2]],
     ],
 )
-def test_main_progress_bar_update_amount(
+def test_train_progress_bar_update_amount(
     tmpdir, train_batches: int, val_batches: int, refresh_rate: int, train_updates, val_updates
 ):
-    """Test that the main progress updates with the correct amount together with the val progress.
+    """Test that the train progress updates with the correct amount together with the val progress.
 
     At the end of the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh
     rate.
@@ -376,7 +370,7 @@ def test_main_progress_bar_update_amount(
     with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm):
         trainer.fit(model)
     if train_batches > 0:
-        assert progress_bar.main_progress_bar.n_values == train_updates
+        assert progress_bar.train_progress_bar.n_values == train_updates
     if val_batches > 0:
         assert progress_bar.val_progress_bar.n_values == val_updates
@@ -417,7 +411,7 @@ def training_step(self, batch, batch_idx):
     torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123)
     assert trainer.progress_bar_metrics["b"] == 1.0
     assert trainer.progress_bar_metrics["c"] == 2.0
-    pbar = trainer.progress_bar_callback.main_progress_bar
+    pbar = trainer.progress_bar_callback.train_progress_bar
     actual = str(pbar.postfix)
     assert actual.endswith("a=0.123, b=1.000, c=2.000"), actual
@@ -561,11 +555,11 @@ def test_tqdm_progress_bar_can_be_pickled():
 
 @pytest.mark.parametrize(
-    ["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"],
-    [(4, [0, 3, 6, 9, 12, 14], [0, 3, 6, 7]), (0.5, [0, 3, 6, 9, 12, 15, 18, 21], [0, 3, 6, 7])],
+    ["val_check_interval", "train_progress_bar_updates", "val_progress_bar_updates"],
+    [(4, [0, 3, 6, 7], [0, 3, 6, 7]), (0.5, [0, 3, 6, 7], [0, 3, 6, 7])],
 )
 def test_progress_bar_max_val_check_interval(
-    tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates
+    tmpdir, val_check_interval, train_progress_bar_updates, val_progress_bar_updates
 ):
     limit_batches = 7
     model = BoringModel()
@@ -583,21 +577,20 @@ def test_progress_bar_max_val_check_interval(
     trainer.fit(model)
 
     pbar = trainer.progress_bar_callback
-    assert pbar.main_progress_bar.n_values == main_progress_bar_updates
+    assert pbar.train_progress_bar.n_values == train_progress_bar_updates
     assert pbar.val_progress_bar.n_values == val_progress_bar_updates
 
     val_check_batch = (
         max(1, int(limit_batches * val_check_interval))
         if isinstance(val_check_interval, float)
         else val_check_interval
     )
     assert trainer.val_check_batch == val_check_batch
-    val_checks_per_epoch = math.ceil(limit_batches // val_check_batch)
+    math.ceil(limit_batches // val_check_batch)
     pbar_callback = trainer.progress_bar_callback
-    total_val_batches = limit_batches * val_checks_per_epoch
     assert pbar_callback.val_progress_bar.n == limit_batches
     assert pbar_callback.val_progress_bar.total == limit_batches
-    assert pbar_callback.main_progress_bar.n == limit_batches + total_val_batches
-    assert pbar_callback.main_progress_bar.total == limit_batches + total_val_batches
+    assert pbar_callback.train_progress_bar.n == limit_batches
+    assert pbar_callback.train_progress_bar.total == limit_batches
     assert pbar_callback.is_enabled
@@ -629,16 +622,14 @@ def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval):
     total_train_batches = total_train_samples // (train_batch_size * world_size)
     val_check_batch = max(1, int(total_train_batches * val_check_interval))
     assert trainer.val_check_batch == val_check_batch
-    val_checks_per_epoch = total_train_batches / val_check_batch
     total_val_batches = total_val_samples // (val_batch_size * world_size)
     pbar_callback = trainer.progress_bar_callback
     if trainer.is_global_zero:
         assert pbar_callback.val_progress_bar.n == total_val_batches
         assert pbar_callback.val_progress_bar.total == total_val_batches
-        total_val_batches = total_val_batches * val_checks_per_epoch
-        assert pbar_callback.main_progress_bar.n == (total_train_batches + total_val_batches) // world_size
-        assert pbar_callback.main_progress_bar.total == (total_train_batches + total_val_batches) // world_size
+        assert pbar_callback.train_progress_bar.n == total_train_batches // world_size
+        assert pbar_callback.train_progress_bar.total == total_train_batches // world_size
         assert pbar_callback.is_enabled
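
To see the new accounting end to end, here is a minimal sketch (not part of the patch) that runs a short fit and checks the renamed attribute. It assumes an install of `lightning` that includes this change and reuses the `BoringModel` demo class the tests above already import; the batch counts are illustrative.

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.demos.boring_classes import BoringModel

# 5 train batches and 2 val batches per epoch; illustrative numbers only.
trainer = Trainer(limit_train_batches=5, limit_val_batches=2, max_epochs=1)
trainer.fit(BoringModel())

bar = trainer.progress_bar_callback
assert isinstance(bar, TQDMProgressBar)
# With this patch the train bar counts only training batches; before it,
# the (then) main bar would have shown 5 + 2 via `total_batches_current_epoch`.
assert bar.train_progress_bar.total == 5
```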
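Since the rename is user-visible, subclasses that customized the bar must follow it. A hypothetical custom bar (the `FixedWidthBar` name and the `ncols` tweak are illustrative, not from the patch) now reads:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import TQDMProgressBar
from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm


class FixedWidthBar(TQDMProgressBar):
    """Example override; `main_progress_bar` accesses must become `train_progress_bar`."""

    def init_train_tqdm(self) -> Tqdm:
        bar = super().init_train_tqdm()
        bar.ncols = 100  # pin the width instead of resizing with the terminal
        return bar


trainer = Trainer(callbacks=[FixedWidthBar()])
```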
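Finally, as a cross-check on the deleted `test_base_progress.py`, the removed `total_batches_current_epoch` formula does reproduce its `[15, 25, 25, 15]` expectation. A standalone re-derivation of just that arithmetic (no Lightning imports; parameter names shortened for brevity):

```python
def total_batches_current_epoch(processed: int, train: int = 5, val: int = 10, check: int = 3) -> int:
    # check_val_every_n_epoch=None branch of the removed property: validation
    # runs every `check` train batches across epoch boundaries, so count the
    # multiples of `check` that fall inside this epoch's window of train batches.
    val_checks = (processed + train) // check - processed // check
    return train + val * val_checks


# The deleted test advanced `batch_progress.total.ready` by 5 per simulated epoch.
assert [total_batches_current_epoch(p) for p in (0, 5, 10, 15)] == [15, 25, 25, 15]
```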