From f784f96770bb076a49d85f9b5c601fbdd7fca570 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 8 Feb 2023 23:47:14 +0100 Subject: [PATCH 01/11] split tqdm bar --- .../pytorch/callbacks/progress/base.py | 21 --------- .../callbacks/progress/tqdm_progress.py | 9 +--- .../progress/test_tqdm_progress_bar.py | 45 ++++++++----------- 3 files changed, 21 insertions(+), 54 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/base.py b/src/lightning/pytorch/callbacks/progress/base.py index 05532b48d4b3c..7f396d154e276 100644 --- a/src/lightning/pytorch/callbacks/progress/base.py +++ b/src/lightning/pytorch/callbacks/progress/base.py @@ -170,27 +170,6 @@ def total_val_batches(self) -> Union[int, float]: """ return sum(self.trainer.num_val_batches) if self.trainer.fit_loop.epoch_loop._should_check_val_epoch() else 0 - @property - def total_batches_current_epoch(self) -> Union[int, float]: - total_train_batches = self.total_train_batches - total_val_batches = self.total_val_batches - assert self._trainer is not None - - if total_train_batches != float("inf") and total_val_batches != float("inf"): - # val can be checked multiple times per epoch - val_check_batch = self.trainer.val_check_batch - if self.trainer.check_val_every_n_epoch is None: - train_batches_processed = self.trainer.fit_loop.total_batch_idx + 1 - val_checks_per_epoch = ((train_batches_processed + total_train_batches) // val_check_batch) - ( - train_batches_processed // val_check_batch - ) - else: - val_checks_per_epoch = total_train_batches // val_check_batch - - total_val_batches = total_val_batches * val_checks_per_epoch - - return total_train_batches + total_val_batches - def has_dataloader_changed(self, dataloader_idx: int) -> bool: old_dataloader_idx = self._current_eval_dataloader_idx self._current_eval_dataloader_idx = dataloader_idx diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index 5af994c296c2b..b886a34aae831 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -252,13 +252,12 @@ def on_train_start(self, *_: Any) -> None: self.main_progress_bar = self.init_train_tqdm() def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: - total_batches = self.total_batches_current_epoch - self.main_progress_bar.reset(convert_inf(total_batches)) + self.main_progress_bar.reset(convert_inf(self.total_train_batches)) self.main_progress_bar.initial = 0 self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}") def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None: - current = self.train_batch_idx + self._val_processed + current = self.train_batch_idx if self._should_update(current, self.main_progress_bar.total): _update_n(self.main_progress_bar, current) self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) @@ -289,10 +288,6 @@ def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None: if self._should_update(self.val_batch_idx, self.val_progress_bar.total): _update_n(self.val_progress_bar, self.val_batch_idx) - current = self.train_batch_idx + self._val_processed - if trainer.state.fn == "fit" and self._should_update(current, self.main_progress_bar.total): - _update_n(self.main_progress_bar, current) - def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self._main_progress_bar is not None and 
trainer.state.fn == "fit": self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index 9605bae1b8441..88ae2b6d4863b 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -162,9 +162,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): n = trainer.num_training_batches m = trainer.num_val_batches assert len(trainer.train_dataloader) == n - # main progress bar should have reached the end (train batches + val batches) - assert pbar.main_progress_bar.total == n + sum(m) - assert pbar.main_progress_bar.n == n + sum(m) + # main progress bar should have reached the end + assert pbar.main_progress_bar.total == n + assert pbar.main_progress_bar.n == n assert pbar.main_progress_bar.leave # check val progress bar total @@ -214,9 +214,9 @@ def test_tqdm_progress_bar_fast_dev_run(tmpdir): assert 1 == pbar.val_progress_bar.n assert 1 == pbar.val_progress_bar.total - # the main progress bar should display 2 batches (1 train, 1 val) - assert 2 == pbar.main_progress_bar.total - assert 2 == pbar.main_progress_bar.n + # the main progress bar should display 1 batch + assert 1 == pbar.main_progress_bar.total + assert 1 == pbar.main_progress_bar.n trainer.validate(model) @@ -266,24 +266,18 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal assert trainer.progress_bar_callback.refresh_rate == refresh_rate trainer.fit(model) - assert ( - pbar.train_batches_seen + pbar.val_batches_seen - == 3 * pbar.main_progress_bar.total + trainer.num_sanity_val_steps - ) + assert pbar.train_batches_seen == 3 * pbar.main_progress_bar.total + assert pbar.val_batches_seen == 3 * pbar.val_progress_bar.total + trainer.num_sanity_val_steps assert pbar.test_batches_seen == 0 trainer.validate(model) - assert ( - pbar.train_batches_seen + pbar.val_batches_seen - == 3 * pbar.main_progress_bar.total + pbar.val_progress_bar.total + trainer.num_sanity_val_steps - ) + assert pbar.train_batches_seen == 3 * pbar.main_progress_bar.total + assert pbar.val_batches_seen == 4 * pbar.val_progress_bar.total + trainer.num_sanity_val_steps assert pbar.test_batches_seen == 0 trainer.test(model) - assert ( - pbar.train_batches_seen + pbar.val_batches_seen - == 3 * pbar.main_progress_bar.total + pbar.val_progress_bar.total + trainer.num_sanity_val_steps - ) + assert pbar.train_batches_seen == 3 * pbar.main_progress_bar.total + assert pbar.val_batches_seen == 4 * pbar.val_progress_bar.total + trainer.num_sanity_val_steps assert pbar.test_batches_seen == pbar.test_progress_bar.total @@ -345,13 +339,13 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir): @pytest.mark.parametrize( "train_batches,val_batches,refresh_rate,train_updates,val_updates", [ - [2, 3, 1, [0, 1, 2, 3, 4, 5], [0, 1, 2, 3]], + [2, 3, 1, [0, 1, 2], [0, 1, 2, 3]], [0, 0, 3, None, None], [1, 0, 3, [0, 1], None], - [1, 1, 3, [0, 2], [0, 1]], + [1, 1, 3, [0, 1], [0, 1]], [5, 0, 3, [0, 3, 5], None], - [5, 2, 3, [0, 3, 6, 7], [0, 2]], - [5, 2, 6, [0, 6, 7], [0, 2]], + [5, 2, 3, [0, 3, 5], [0, 2]], + [5, 2, 6, [0, 5], [0, 2]], ], ) def test_main_progress_bar_update_amount( @@ -562,7 +556,7 @@ def test_tqdm_progress_bar_can_be_pickled(): @pytest.mark.parametrize( ["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"], - [(4, [0, 3, 6, 9, 12, 14], [0, 3, 6, 7]), (0.5, 
[0, 3, 6, 9, 12, 15, 18, 21], [0, 3, 6, 7])], + [(4, [0, 3, 6, 7], [0, 3, 6, 7]), (0.5, [0, 3, 6, 7], [0, 3, 6, 7])], ) def test_progress_bar_max_val_check_interval( tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates @@ -592,12 +586,11 @@ def test_progress_bar_max_val_check_interval( assert trainer.val_check_batch == val_check_batch val_checks_per_epoch = math.ceil(limit_batches // val_check_batch) pbar_callback = trainer.progress_bar_callback - total_val_batches = limit_batches * val_checks_per_epoch assert pbar_callback.val_progress_bar.n == limit_batches assert pbar_callback.val_progress_bar.total == limit_batches - assert pbar_callback.main_progress_bar.n == limit_batches + total_val_batches - assert pbar_callback.main_progress_bar.total == limit_batches + total_val_batches + assert pbar_callback.main_progress_bar.n == limit_batches + assert pbar_callback.main_progress_bar.total == limit_batches assert pbar_callback.is_enabled From bd07a8b368655d25d086a808fc51397f93a4c5c5 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 9 Feb 2023 00:00:02 +0100 Subject: [PATCH 02/11] update rich bar --- .../callbacks/progress/rich_progress.py | 7 ++--- .../progress/test_rich_progress_bar.py | 27 ++++++++++--------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index de5b83cb545c5..9933adde17bfb 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -362,7 +362,7 @@ def on_sanity_check_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningMod def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if self.is_disabled: return - total_batches = self.total_batches_current_epoch + total_batches = self.total_train_batches train_description = self._get_train_description(trainer.current_epoch) if self.main_progress_bar_id is not None and self._leave: @@ -470,7 +470,7 @@ def on_predict_batch_start( def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: - self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed) + self._update(self.main_progress_bar_id, self.train_batch_idx) self._update_metrics(trainer, pl_module) self.refresh() @@ -491,9 +491,6 @@ def on_validation_batch_end( if trainer.sanity_checking: self._update(self.val_sanity_progress_bar_id, self.val_batch_idx) elif self.val_progress_bar_id is not None: - # check to see if we should update the main training progress bar - if self.main_progress_bar_id is not None: - self._update(self.main_progress_bar_id, self.train_batch_idx + self._val_processed) self._update(self.val_progress_bar_id, self.val_batch_idx) self.refresh() diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 4f5a4ca92c0f3..99239c12de93f 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -77,8 +77,8 @@ def predict_dataloader(self): with mock.patch("lightning.pytorch.callbacks.progress.rich_progress.Progress.update") as mocked: trainer.fit(model) - # 3 for main progress bar and 1 for val progress bar - assert mocked.call_count == 4 + # 2 for main progress bar and 1 for val progress bar + 
assert mocked.call_count == 3 with mock.patch("lightning.pytorch.callbacks.progress.rich_progress.Progress.update") as mocked: trainer.validate(model) @@ -214,16 +214,17 @@ def test_rich_progress_bar_refresh_rate_disabled(progress_update, tmpdir): @pytest.mark.parametrize( "refresh_rate,train_batches,val_batches,expected_call_count", [ - (3, 6, 6, 4 + 3), - (4, 6, 6, 3 + 3), - (7, 6, 6, 2 + 2), - (1, 2, 3, 5 + 4), + # note: there is always one extra update at the very end (+1) + (3, 6, 6, 2 + 2 + 1), + (4, 6, 6, 2 + 2 + 1), + (7, 6, 6, 1 + 1 + 1), + (1, 2, 3, 2 + 3 + 1), (1, 0, 0, 0 + 0), (3, 1, 0, 1 + 0), - (3, 1, 1, 1 + 2), + (3, 1, 1, 1 + 1 + 1), (3, 5, 0, 2 + 0), - (3, 5, 2, 3 + 2), - (6, 5, 2, 2 + 2), + (3, 5, 2, 2 + 1 + 1), + (6, 5, 2, 1 + 1 + 1), ], ) def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, train_batches, val_batches, expected_call_count): @@ -246,8 +247,8 @@ def test_rich_progress_bar_with_refresh_rate(tmpdir, refresh_rate, train_batches if train_batches > 0: fit_main_bar = trainer.progress_bar_callback.progress.tasks[0] - assert fit_main_bar.completed == train_batches + val_batches - assert fit_main_bar.total == train_batches + val_batches + assert fit_main_bar.completed == train_batches + assert fit_main_bar.total == train_batches assert fit_main_bar.visible if val_batches > 0: fit_val_bar = trainer.progress_bar_callback.progress.tasks[1] @@ -294,8 +295,8 @@ def test_rich_progress_bar_counter_with_val_check_interval(tmpdir): trainer.fit(model) fit_main_progress_bar = progress_bar.progress.tasks[1] - assert fit_main_progress_bar.completed == 7 + 3 * 4 - assert fit_main_progress_bar.total == 7 + 3 * 4 + assert fit_main_progress_bar.completed == 7 + assert fit_main_progress_bar.total == 7 fit_val_bar = progress_bar.progress.tasks[2] assert fit_val_bar.completed == 4 From 212b4940e6e6be47043ef4e914717bb3b5416ed4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Feb 2023 23:06:05 +0000 Subject: [PATCH 03/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index 88ae2b6d4863b..a4f1b4c222cbb 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -584,7 +584,7 @@ def test_progress_bar_max_val_check_interval( max(1, int(limit_batches * val_check_interval)) if isinstance(val_check_interval, float) else val_check_interval ) assert trainer.val_check_batch == val_check_batch - val_checks_per_epoch = math.ceil(limit_batches // val_check_batch) + math.ceil(limit_batches // val_check_batch) pbar_callback = trainer.progress_bar_callback assert pbar_callback.val_progress_bar.n == limit_batches From 45164c073f720be9ae538fd1695cfc3f35406c0c Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 9 Feb 2023 00:15:36 +0100 Subject: [PATCH 04/11] update test --- .../callbacks/progress/test_tqdm_progress_bar.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index 88ae2b6d4863b..210836bcd9525 100644 --- 
a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -622,16 +622,14 @@ def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval): total_train_batches = total_train_samples // (train_batch_size * world_size) val_check_batch = max(1, int(total_train_batches * val_check_interval)) assert trainer.val_check_batch == val_check_batch - val_checks_per_epoch = total_train_batches / val_check_batch total_val_batches = total_val_samples // (val_batch_size * world_size) pbar_callback = trainer.progress_bar_callback if trainer.is_global_zero: assert pbar_callback.val_progress_bar.n == total_val_batches assert pbar_callback.val_progress_bar.total == total_val_batches - total_val_batches = total_val_batches * val_checks_per_epoch - assert pbar_callback.main_progress_bar.n == (total_train_batches + total_val_batches) // world_size - assert pbar_callback.main_progress_bar.total == (total_train_batches + total_val_batches) // world_size + assert pbar_callback.main_progress_bar.n == total_train_batches // world_size + assert pbar_callback.main_progress_bar.total == total_train_batches // world_size assert pbar_callback.is_enabled From fe411ce49aac9fc5a3790684c7a502f5004973b1 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 9 Feb 2023 00:18:05 +0100 Subject: [PATCH 05/11] changelog --- src/lightning/pytorch/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 21ee565f07c69..f17c671865efd 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -58,6 +58,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The strategy selected by `accelerator="hpu"` now defaults to `find_unused_parameters=False` ([#16611](https://github.com/Lightning-AI/lightning/pull/16611)) +- The main progress bar displayed during training no longer includes the combined progress for validation ([#16695](https://github.com/Lightning-AI/lightning/pull/16695)) + ### Deprecated From 21f6f1914c9cc3b880d532bce5e7fef0e8895962 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 9 Feb 2023 00:21:09 +0100 Subject: [PATCH 06/11] remove redundant test --- .../callbacks/progress/test_base_progress.py | 34 ------------------- 1 file changed, 34 deletions(-) delete mode 100644 tests/tests_pytorch/callbacks/progress/test_base_progress.py diff --git a/tests/tests_pytorch/callbacks/progress/test_base_progress.py b/tests/tests_pytorch/callbacks/progress/test_base_progress.py deleted file mode 100644 index 4ece74473b116..0000000000000 --- a/tests/tests_pytorch/callbacks/progress/test_base_progress.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.trainer.trainer import Trainer
-
-
-def test_main_progress_bar_with_val_check_interval_int():
-    """Test the main progress bar count when val_check_interval=int and check_val_every_n_epoch=None."""
-    train_batches = 5
-    trainer = Trainer(
-        limit_train_batches=train_batches, limit_val_batches=10, val_check_interval=3, check_val_every_n_epoch=None
-    )
-    model = BoringModel()
-    trainer.progress_bar_callback.setup(trainer, model, stage="fit")
-    trainer.strategy.connect(model)
-    trainer._data_connector.attach_data(model)
-    trainer.reset_train_dataloader()
-    trainer.reset_val_dataloader()
-    expected = [15, 25, 25, 15]
-
-    for count in expected:
-        assert trainer.progress_bar_callback.total_batches_current_epoch == count
-        trainer.fit_loop.epoch_loop.batch_progress.total.ready += train_batches

From 1781ba6e2434c21b77fb99c0a249dc35d7f3e2e3 Mon Sep 17 00:00:00 2001
From: awaelchli
Date: Thu, 9 Feb 2023 00:54:16 +0100
Subject: [PATCH 07/11] update docs

---
 docs/source-pytorch/common/progress_bar.rst | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source-pytorch/common/progress_bar.rst b/docs/source-pytorch/common/progress_bar.rst
index d00c716bb83cf..45a638570dee2 100644
--- a/docs/source-pytorch/common/progress_bar.rst
+++ b/docs/source-pytorch/common/progress_bar.rst
@@ -22,7 +22,7 @@ The :class:`~pytorch_lightning.callbacks.TQDMProgressBar` uses the `tqdm

From: awaelchli
Date: Thu, 9 Feb 2023 00:56:29 +0100
Subject: [PATCH 08/11] update docs

---
 src/lightning/pytorch/callbacks/progress/tqdm_progress.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
index b886a34aae831..ec59a25ba3458 100644
--- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -63,8 +63,8 @@ class TQDMProgressBar(ProgressBarBase):
     :mod:`tqdm` package and shows up to four different bars:

     - **sanity check progress:** the progress during the sanity check run
-    - **main progress:** shows training + validation progress combined. It also accounts for
-      multiple validation runs during training when
+    - **main progress:** shows the training progress. It will pause if validation starts and will resume
+      when it ends, and also accounts for multiple validation runs during training when
       :paramref:`~lightning.pytorch.trainer.trainer.Trainer.val_check_interval` is used.
     - **validation progress:** only visible during validation;
       shows total progress over all validation datasets.
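Taken together, patches 01 through 08 change what the main bar counts: training batches only, with validation tracked by its own bar. Below is a minimal sketch of the resulting behavior, assuming a build with these patches applied; `Trainer`, `BoringModel`, and the progress-bar attributes all appear in the diffs and tests above, and the attribute is still called `main_progress_bar` at this point in the series (patch 09 renames it):

    from lightning.pytorch.demos.boring_classes import BoringModel
    from lightning.pytorch.trainer.trainer import Trainer

    # 5 train batches and 2 val batches per epoch. Before this series the main
    # bar would finish at total=7 (train + val); with these patches it ends at 5.
    trainer = Trainer(limit_train_batches=5, limit_val_batches=2, max_epochs=1)
    trainer.fit(BoringModel())

    pbar = trainer.progress_bar_callback
    assert pbar.main_progress_bar.total == 5  # train batches only
    assert pbar.val_progress_bar.total == 2   # validation has its own bar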
From 0c06145eee60afc7e2ea78af27a171592c5ba8cf Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 9 Feb 2023 01:09:36 +0100 Subject: [PATCH 09/11] main -> train --- .../callbacks/progress/rich_progress.py | 22 ++++----- .../callbacks/progress/tqdm_progress.py | 48 +++++++++---------- .../progress/test_rich_progress_bar.py | 16 +++---- .../progress/test_tqdm_progress_bar.py | 36 +++++++------- 4 files changed, 61 insertions(+), 61 deletions(-) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 9933adde17bfb..0e7871ab16f9f 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -156,7 +156,7 @@ def render(self, task: "Task") -> Text: if ( self._trainer.state.fn != "fit" or self._trainer.sanity_checking - or self._trainer.progress_bar_callback.main_progress_bar_id != task.id + or self._trainer.progress_bar_callback.train_progress_bar_id != task.id ): return Text() if self._trainer.training and task.id not in self._tasks: @@ -256,7 +256,7 @@ def __init__( self._console_kwargs = console_kwargs or {} self._enabled: bool = True self.progress: Optional[CustomProgress] = None - self.main_progress_bar_id: Optional["TaskID"] + self.train_progress_bar_id: Optional["TaskID"] self.val_sanity_progress_bar_id: Optional["TaskID"] = None self.val_progress_bar_id: Optional["TaskID"] self.test_progress_bar_id: Optional["TaskID"] @@ -280,10 +280,10 @@ def is_disabled(self) -> bool: return not self.is_enabled @property - def main_progress_bar(self) -> Task: + def train_progress_bar(self) -> Task: assert self.progress is not None - assert self.main_progress_bar_id is not None - return self.progress.tasks[self.main_progress_bar_id] + assert self.train_progress_bar_id is not None + return self.progress.tasks[self.train_progress_bar_id] @property def val_sanity_check_bar(self) -> Task: @@ -365,15 +365,15 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo total_batches = self.total_train_batches train_description = self._get_train_description(trainer.current_epoch) - if self.main_progress_bar_id is not None and self._leave: + if self.train_progress_bar_id is not None and self._leave: self._stop_progress() self._init_progress(trainer) if self.progress is not None: - if self.main_progress_bar_id is None: - self.main_progress_bar_id = self._add_task(total_batches, train_description) + if self.train_progress_bar_id is None: + self.train_progress_bar_id = self._add_task(total_batches, train_description) else: self.progress.reset( - self.main_progress_bar_id, total=total_batches, description=train_description, visible=True + self.train_progress_bar_id, total=total_batches, description=train_description, visible=True ) self.refresh() @@ -470,7 +470,7 @@ def on_predict_batch_start( def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int ) -> None: - self._update(self.main_progress_bar_id, self.train_batch_idx) + self._update(self.train_progress_bar_id, self.train_batch_idx) self._update_metrics(trainer, pl_module) self.refresh() @@ -541,7 +541,7 @@ def _stop_progress(self) -> None: self._progress_stopped = True def _reset_progress_bar_ids(self) -> None: - self.main_progress_bar_id = None + self.train_progress_bar_id = None self.val_sanity_progress_bar_id = None self.val_progress_bar_id = None self.test_progress_bar_id = None diff --git 
a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index ec59a25ba3458..0130fde5f57ba 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -103,7 +103,7 @@ def __init__(self, refresh_rate: int = 1, process_position: int = 0): self._refresh_rate = self._resolve_refresh_rate(refresh_rate) self._process_position = process_position self._enabled = True - self._main_progress_bar: Optional[_tqdm] = None + self._train_progress_bar: Optional[_tqdm] = None self._val_progress_bar: Optional[_tqdm] = None self._test_progress_bar: Optional[_tqdm] = None self._predict_progress_bar: Optional[_tqdm] = None @@ -113,14 +113,14 @@ def __getstate__(self) -> Dict: return {k: v if not isinstance(v, _tqdm) else None for k, v in vars(self).items()} @property - def main_progress_bar(self) -> _tqdm: - if self._main_progress_bar is None: - raise TypeError(f"The `{self.__class__.__name__}._main_progress_bar` reference has not been set yet.") - return self._main_progress_bar + def train_progress_bar(self) -> _tqdm: + if self._train_progress_bar is None: + raise TypeError(f"The `{self.__class__.__name__}._train_progress_bar` reference has not been set yet.") + return self._train_progress_bar - @main_progress_bar.setter - def main_progress_bar(self, bar: _tqdm) -> None: - self._main_progress_bar = bar + @train_progress_bar.setter + def train_progress_bar(self, bar: _tqdm) -> None: + self._train_progress_bar = bar @property def val_progress_bar(self) -> _tqdm: @@ -242,32 +242,32 @@ def init_test_tqdm(self) -> Tqdm: def on_sanity_check_start(self, *_: Any) -> None: self.val_progress_bar = self.init_sanity_tqdm() - self.main_progress_bar = Tqdm(disable=True) # dummy progress bar + self.train_progress_bar = Tqdm(disable=True) # dummy progress bar def on_sanity_check_end(self, *_: Any) -> None: - self.main_progress_bar.close() + self.train_progress_bar.close() self.val_progress_bar.close() def on_train_start(self, *_: Any) -> None: - self.main_progress_bar = self.init_train_tqdm() + self.train_progress_bar = self.init_train_tqdm() def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: - self.main_progress_bar.reset(convert_inf(self.total_train_batches)) - self.main_progress_bar.initial = 0 - self.main_progress_bar.set_description(f"Epoch {trainer.current_epoch}") + self.train_progress_bar.reset(convert_inf(self.total_train_batches)) + self.train_progress_bar.initial = 0 + self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") def on_train_batch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", *_: Any) -> None: current = self.train_batch_idx - if self._should_update(current, self.main_progress_bar.total): - _update_n(self.main_progress_bar, current) - self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) + if self._should_update(current, self.train_progress_bar.total): + _update_n(self.train_progress_bar, current) + self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if not self.main_progress_bar.disable: - self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) + if not self.train_progress_bar.disable: + self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) def on_train_end(self, *_: Any) -> None: - self.main_progress_bar.close() + 
self.train_progress_bar.close() def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: if not trainer.sanity_checking: @@ -289,8 +289,8 @@ def on_validation_batch_end(self, trainer: "pl.Trainer", *_: Any) -> None: _update_n(self.val_progress_bar, self.val_batch_idx) def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - if self._main_progress_bar is not None and trainer.state.fn == "fit": - self.main_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) + if self._train_progress_bar is not None and trainer.state.fn == "fit": + self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) self.val_progress_bar.close() self.reset_dataloader_idx_tracker() @@ -339,8 +339,8 @@ def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") def print(self, *args: Any, sep: str = " ", **kwargs: Any) -> None: active_progress_bar = None - if self._main_progress_bar is not None and not self.main_progress_bar.disable: - active_progress_bar = self.main_progress_bar + if self._train_progress_bar is not None and not self.train_progress_bar.disable: + active_progress_bar = self.train_progress_bar elif self._val_progress_bar is not None and not self.val_progress_bar.disable: active_progress_bar = self.val_progress_bar elif self._test_progress_bar is not None and not self.test_progress_bar.disable: diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 99239c12de93f..b3b5bfc2b75de 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -294,9 +294,9 @@ def test_rich_progress_bar_counter_with_val_check_interval(tmpdir): ) trainer.fit(model) - fit_main_progress_bar = progress_bar.progress.tasks[1] - assert fit_main_progress_bar.completed == 7 - assert fit_main_progress_bar.total == 7 + fit_train_progress_bar = progress_bar.progress.tasks[1] + assert fit_train_progress_bar.completed == 7 + assert fit_train_progress_bar.total == 7 fit_val_bar = progress_bar.progress.tasks[2] assert fit_val_bar.completed == 4 @@ -335,12 +335,12 @@ def training_step(self, *args, **kwargs): trainer = Trainer(default_root_dir=tmpdir, callbacks=progress_bar, fast_dev_run=True, logger=CSVLogger(tmpdir)) trainer.fit(model) - main_progress_bar_id = progress_bar.main_progress_bar_id + train_progress_bar_id = progress_bar.train_progress_bar_id val_progress_bar_id = progress_bar.val_progress_bar_id rendered = progress_bar.progress.columns[-1]._renderable_cache for key in ("loss", "v_num", "train_loss"): - assert key in rendered[main_progress_bar_id][1] + assert key in rendered[train_progress_bar_id][1] assert key not in rendered[val_progress_bar_id][1] @@ -455,7 +455,7 @@ def test_rich_progress_bar_reset_bars(): assert bar._progress_stopped is False def _set_fake_bar_ids(): - bar.main_progress_bar_id = 0 + bar.train_progress_bar_id = 0 bar.val_sanity_progress_bar_id = 1 bar.val_progress_bar_id = 2 bar.test_progress_bar_id = 3 @@ -470,7 +470,7 @@ def _set_fake_bar_ids(): bar.teardown(Mock(), Mock(), Mock()) # assert all bars are reset - assert bar.main_progress_bar_id is None + assert bar.train_progress_bar_id is None assert bar.val_sanity_progress_bar_id is None assert bar.val_progress_bar_id is None assert bar.test_progress_bar_id is None @@ -507,7 +507,7 @@ def test_rich_progress_bar_disabled(tmpdir): trainer.predict(model) 
mocked.assert_not_called() - assert bar.main_progress_bar_id is None + assert bar.train_progress_bar_id is None assert bar.val_sanity_progress_bar_id is None assert bar.val_progress_bar_id is None assert bar.test_progress_bar_id is None diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index a101958eea32f..e52a6aa8f3929 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -163,9 +163,9 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None): m = trainer.num_val_batches assert len(trainer.train_dataloader) == n # main progress bar should have reached the end - assert pbar.main_progress_bar.total == n - assert pbar.main_progress_bar.n == n - assert pbar.main_progress_bar.leave + assert pbar.train_progress_bar.total == n + assert pbar.train_progress_bar.n == n + assert pbar.train_progress_bar.leave # check val progress bar total assert pbar.val_progress_bar.total_values == m @@ -215,8 +215,8 @@ def test_tqdm_progress_bar_fast_dev_run(tmpdir): assert 1 == pbar.val_progress_bar.total # the main progress bar should display 1 batch - assert 1 == pbar.main_progress_bar.total - assert 1 == pbar.main_progress_bar.n + assert 1 == pbar.train_progress_bar.total + assert 1 == pbar.train_progress_bar.n trainer.validate(model) @@ -266,17 +266,17 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal assert trainer.progress_bar_callback.refresh_rate == refresh_rate trainer.fit(model) - assert pbar.train_batches_seen == 3 * pbar.main_progress_bar.total + assert pbar.train_batches_seen == 3 * pbar.train_progress_bar.total assert pbar.val_batches_seen == 3 * pbar.val_progress_bar.total + trainer.num_sanity_val_steps assert pbar.test_batches_seen == 0 trainer.validate(model) - assert pbar.train_batches_seen == 3 * pbar.main_progress_bar.total + assert pbar.train_batches_seen == 3 * pbar.train_progress_bar.total assert pbar.val_batches_seen == 4 * pbar.val_progress_bar.total + trainer.num_sanity_val_steps assert pbar.test_batches_seen == 0 trainer.test(model) - assert pbar.train_batches_seen == 3 * pbar.main_progress_bar.total + assert pbar.train_batches_seen == 3 * pbar.train_progress_bar.total assert pbar.val_batches_seen == 4 * pbar.val_progress_bar.total + trainer.num_sanity_val_steps assert pbar.test_batches_seen == pbar.test_progress_bar.total @@ -348,7 +348,7 @@ def test_tqdm_progress_bar_value_on_colab(tmpdir): [5, 2, 6, [0, 5], [0, 2]], ], ) -def test_main_progress_bar_update_amount( +def test_train_progress_bar_update_amount( tmpdir, train_batches: int, val_batches: int, refresh_rate: int, train_updates, val_updates ): """Test that the main progress updates with the correct amount together with the val progress. 
@@ -370,7 +370,7 @@ def test_main_progress_bar_update_amount( with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): trainer.fit(model) if train_batches > 0: - assert progress_bar.main_progress_bar.n_values == train_updates + assert progress_bar.train_progress_bar.n_values == train_updates if val_batches > 0: assert progress_bar.val_progress_bar.n_values == val_updates @@ -411,7 +411,7 @@ def training_step(self, batch, batch_idx): torch.testing.assert_close(trainer.progress_bar_metrics["a"], 0.123) assert trainer.progress_bar_metrics["b"] == 1.0 assert trainer.progress_bar_metrics["c"] == 2.0 - pbar = trainer.progress_bar_callback.main_progress_bar + pbar = trainer.progress_bar_callback.train_progress_bar actual = str(pbar.postfix) assert actual.endswith("a=0.123, b=1.000, c=2.000"), actual @@ -555,11 +555,11 @@ def test_tqdm_progress_bar_can_be_pickled(): @pytest.mark.parametrize( - ["val_check_interval", "main_progress_bar_updates", "val_progress_bar_updates"], + ["val_check_interval", "train_progress_bar_updates", "val_progress_bar_updates"], [(4, [0, 3, 6, 7], [0, 3, 6, 7]), (0.5, [0, 3, 6, 7], [0, 3, 6, 7])], ) def test_progress_bar_max_val_check_interval( - tmpdir, val_check_interval, main_progress_bar_updates, val_progress_bar_updates + tmpdir, val_check_interval, train_progress_bar_updates, val_progress_bar_updates ): limit_batches = 7 model = BoringModel() @@ -577,7 +577,7 @@ def test_progress_bar_max_val_check_interval( trainer.fit(model) pbar = trainer.progress_bar_callback - assert pbar.main_progress_bar.n_values == main_progress_bar_updates + assert pbar.train_progress_bar.n_values == train_progress_bar_updates assert pbar.val_progress_bar.n_values == val_progress_bar_updates val_check_batch = ( @@ -589,8 +589,8 @@ def test_progress_bar_max_val_check_interval( assert pbar_callback.val_progress_bar.n == limit_batches assert pbar_callback.val_progress_bar.total == limit_batches - assert pbar_callback.main_progress_bar.n == limit_batches - assert pbar_callback.main_progress_bar.total == limit_batches + assert pbar_callback.train_progress_bar.n == limit_batches + assert pbar_callback.train_progress_bar.total == limit_batches assert pbar_callback.is_enabled @@ -628,8 +628,8 @@ def test_progress_bar_max_val_check_interval_ddp(tmpdir, val_check_interval): if trainer.is_global_zero: assert pbar_callback.val_progress_bar.n == total_val_batches assert pbar_callback.val_progress_bar.total == total_val_batches - assert pbar_callback.main_progress_bar.n == total_train_batches // world_size - assert pbar_callback.main_progress_bar.total == total_train_batches // world_size + assert pbar_callback.train_progress_bar.n == total_train_batches // world_size + assert pbar_callback.train_progress_bar.total == total_train_batches // world_size assert pbar_callback.is_enabled From 5f193fc88ea3836ebf03b46b5299a0b258e0af04 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 9 Feb 2023 18:49:51 +0100 Subject: [PATCH 10/11] left-over renames --- docs/source-pytorch/common/progress_bar.rst | 2 +- src/lightning/pytorch/callbacks/progress/base.py | 2 +- src/lightning/pytorch/callbacks/progress/tqdm_progress.py | 4 ++-- .../callbacks/progress/test_rich_progress_bar.py | 2 +- .../callbacks/progress/test_tqdm_progress_bar.py | 6 +++--- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/source-pytorch/common/progress_bar.rst b/docs/source-pytorch/common/progress_bar.rst index 45a638570dee2..e03c6491edc3e 100644 --- a/docs/source-pytorch/common/progress_bar.rst +++ 
b/docs/source-pytorch/common/progress_bar.rst
@@ -22,7 +22,7 @@ The :class:`~pytorch_lightning.callbacks.TQDMProgressBar` uses the `tqdm
diff --git a/src/lightning/pytorch/callbacks/progress/base.py b/src/lightning/pytorch/callbacks/progress/base.py
--- a/src/lightning/pytorch/callbacks/progress/base.py
+++ b/src/lightning/pytorch/callbacks/progress/base.py
@@ def enable(self) -> None:
         The :class:`~lightning.pytorch.trainer.trainer.Trainer` will call this in e.g. pre-training
         routines like the :ref:`learning rate finder <learning_rate_finder>`.
-        to temporarily enable and disable the main progress bar.
+        to temporarily enable and disable the training progress bar.
         """
         raise NotImplementedError
diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
index 0130fde5f57ba..bea7251ad168a 100644
--- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -63,7 +63,7 @@ class TQDMProgressBar(ProgressBarBase):
     :mod:`tqdm` package and shows up to four different bars:

     - **sanity check progress:** the progress during the sanity check run
-    - **main progress:** shows the training progress. It will pause if validation starts and will resume
+    - **train progress:** shows the training progress. It will pause if validation starts and will resume
       when it ends, and also accounts for multiple validation runs during training when
       :paramref:`~lightning.pytorch.trainer.trainer.Trainer.val_check_interval` is used.
     - **validation progress:** only visible during validation;
@@ -216,7 +216,7 @@ def init_predict_tqdm(self) -> Tqdm:

     def init_validation_tqdm(self) -> Tqdm:
         """Override this to customize the tqdm bar for validation."""
-        # The main progress bar doesn't exist in `trainer.validate()`
+        # The train progress bar doesn't exist in `trainer.validate()`
         has_main_bar = self.trainer.state.fn != "validate"
         bar = Tqdm(
             desc=self.validation_description,
diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
index b3b5bfc2b75de..6800d622e8fd4 100644
--- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py
@@ -77,7 +77,7 @@ def predict_dataloader(self):

     with mock.patch("lightning.pytorch.callbacks.progress.rich_progress.Progress.update") as mocked:
         trainer.fit(model)
-    # 2 for main progress bar and 1 for val progress bar
+    # 2 for train progress bar and 1 for val progress bar
     assert mocked.call_count == 3

     with mock.patch("lightning.pytorch.callbacks.progress.rich_progress.Progress.update") as mocked:
diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
index e52a6aa8f3929..2181f5a842aac 100644
--- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
@@ -162,7 +162,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=None):
     n = trainer.num_training_batches
     m = trainer.num_val_batches
     assert len(trainer.train_dataloader) == n
-    # main progress bar should have reached the end
+    # train progress bar should have reached the end
     assert pbar.train_progress_bar.total == n
     assert pbar.train_progress_bar.n == n
     assert pbar.train_progress_bar.leave
@@ -214,7 +214,7 @@ def test_tqdm_progress_bar_fast_dev_run(tmpdir):
     assert 1 == pbar.val_progress_bar.n
     assert 1 == pbar.val_progress_bar.total

-    # the main progress bar should display 1 batch
+    # the train progress bar should display 1 batch
     assert 1 == pbar.train_progress_bar.total
     assert 1 == pbar.train_progress_bar.n
@@ -351,7 +351,7 @@ def
test_tqdm_progress_bar_value_on_colab(tmpdir): def test_train_progress_bar_update_amount( tmpdir, train_batches: int, val_batches: int, refresh_rate: int, train_updates, val_updates ): - """Test that the main progress updates with the correct amount together with the val progress. + """Test that the train progress updates with the correct amount together with the val progress. At the end of the epoch, the progress must not overshoot if the number of steps is not divisible by the refresh rate. From bdccf1a06bdef7f69da01dbda16c62287c385556 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Thu, 9 Feb 2023 18:52:14 +0100 Subject: [PATCH 11/11] changelog --- src/lightning/pytorch/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index f17c671865efd..71befd9ecedb8 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - The main progress bar displayed during training no longer includes the combined progress for validation ([#16695](https://github.com/Lightning-AI/lightning/pull/16695)) +- Renamed `TQDMProgressBar.main_progress_bar` to `TQDMProgressBar.train_progress_bar` ([#16695](https://github.com/Lightning-AI/lightning/pull/16695)) + ### Deprecated
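For downstream code, the user-visible API change in patches 09 through 11 is the rename: `main_progress_bar` becomes `train_progress_bar` (and `main_progress_bar_id` becomes `train_progress_bar_id` on the rich bar). A short migration sketch follows; the `CustomBar` subclass is illustrative and not part of the PR, while `TQDMProgressBar` and its `init_train_tqdm` hook are taken from the diffs above:

    from lightning.pytorch.callbacks import TQDMProgressBar

    class CustomBar(TQDMProgressBar):
        """Illustrative subclass: customize the bar that now tracks training batches only."""

        def init_train_tqdm(self):
            bar = super().init_train_tqdm()
            bar.unit = "batch"  # plain tqdm attribute; the total no longer includes val batches
            return bar

    bar = CustomBar()
    # Pre-PR user code:  bar.main_progress_bar
    # Post-PR user code: bar.train_progress_bar
    # The old attribute is removed outright, so stale references raise AttributeError;
    # there is no deprecation shim in this series.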