diff --git a/pytorch_lightning/callbacks/progress/rich_progress.py b/pytorch_lightning/callbacks/progress/rich_progress.py index 1fb5307368483..95064a33c5152 100644 --- a/pytorch_lightning/callbacks/progress/rich_progress.py +++ b/pytorch_lightning/callbacks/progress/rich_progress.py @@ -37,7 +37,7 @@ def render(self, task: "Task") -> ProgressBar: total=max(0, task.total), completed=max(0, task.completed), width=None if self.bar_width is None else max(1, self.bar_width), - pulse=not task.started or math.isfinite(task.remaining), + pulse=not task.started or not math.isfinite(task.remaining), animation_time=task.get_time(), style=self.style, complete_style=self.complete_style, @@ -129,13 +129,19 @@ def render(self, task) -> RenderableType: class MetricsTextColumn(ProgressColumn): """A column containing text.""" - def __init__(self, trainer, pl_module): + def __init__(self, trainer): self._trainer = trainer - self._pl_module = pl_module self._tasks = {} self._current_task_id = 0 + self.metrics = {} super().__init__() + def update(self, metrics): + # called when metrics are ready to be rendered. + # this is due to preventing render from causing deadlock issues by requesting metrics + # in separate thread. + self.metrics = metrics + def render(self, task) -> Text: from pytorch_lightning.trainer.states import TrainerFn @@ -149,14 +155,8 @@ def render(self, task) -> Text: if self._trainer.training and task.id != self._current_task_id: return self._tasks[task.id] _text = "" - # TODO(@daniellepintz): make this code cleaner - progress_bar_callback = getattr(self._trainer, "progress_bar_callback", None) - if progress_bar_callback: - metrics = self._trainer.progress_bar_callback.get_metrics(self._trainer, self._pl_module) - else: - metrics = self._trainer.progress_bar_metrics - - for k, v in metrics.items(): + + for k, v in self.metrics.items(): _text += f"{k}: {round(v, 3) if isinstance(v, float) else v} " return Text(_text, justify="left") @@ -194,7 +194,11 @@ class RichProgressBar(ProgressBarBase): trainer = Trainer(callbacks=RichProgressBar()) Args: - refresh_rate_per_second: the number of updates per second. If refresh_rate is 0, progress bar is disabled. + refresh_rate: Determines at which rate (in number of batches) the progress bars get updated. + Set it to ``0`` to disable the display. By default, the :class:`~pytorch_lightning.trainer.trainer.Trainer` + uses this implementation of the progress bar and sets the refresh rate to the value provided to the + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.progress_bar_refresh_rate` argument in the + :class:`~pytorch_lightning.trainer.trainer.Trainer`. theme: Contains styles used to stylize the progress bar. Raises: @@ -204,7 +208,7 @@ class RichProgressBar(ProgressBarBase): def __init__( self, - refresh_rate_per_second: int = 10, + refresh_rate: int = 1, theme: RichProgressBarTheme = RichProgressBarTheme(), ) -> None: if not _RICH_AVAILABLE: @@ -212,27 +216,22 @@ def __init__( "`RichProgressBar` requires `rich` to be installed. Install it by running `pip install -U rich`." ) super().__init__() - self._refresh_rate_per_second: int = refresh_rate_per_second + self._refresh_rate: int = refresh_rate self._enabled: bool = True self.progress: Optional[Progress] = None self.val_sanity_progress_bar_id: Optional[int] = None self._reset_progress_bar_ids() + self._metric_component = None self._progress_stopped: bool = False self.theme = theme - self._console: Console = Console() @property - def refresh_rate_per_second(self) -> float: - """Refresh rate for Rich Progress. - - Returns: Refresh rate for Progress Bar. - Return 1 if not enabled, as a positive integer is required (ignored by Rich Progress). - """ - return self._refresh_rate_per_second if self._refresh_rate_per_second > 0 else 1 + def refresh_rate(self) -> float: + return self._refresh_rate @property def is_enabled(self) -> bool: - return self._enabled and self._refresh_rate_per_second > 0 + return self._enabled and self.refresh_rate > 0 @property def is_disabled(self) -> bool: @@ -260,10 +259,12 @@ def test_description(self) -> str: def predict_description(self) -> str: return "Predicting" - def _init_progress(self, trainer, pl_module): - if self.progress is None or self._progress_stopped: + def _init_progress(self, trainer): + if self.is_enabled and (self.progress is None or self._progress_stopped): self._reset_progress_bar_ids() + self._console: Console = Console() self._console.clear_live() + self._metric_component = MetricsTextColumn(trainer) self.progress = CustomProgress( TextColumn("[progress.description]{task.description}"), CustomBarColumn( @@ -274,8 +275,8 @@ def _init_progress(self, trainer, pl_module): BatchesProcessedColumn(style=self.theme.batch_process), CustomTimeColumn(style=self.theme.time), ProcessingSpeedColumn(style=self.theme.processing_speed), - MetricsTextColumn(trainer, pl_module), - refresh_per_second=self.refresh_rate_per_second, + self._metric_component, + auto_refresh=False, disable=self.is_disabled, console=self._console, ) @@ -283,42 +284,47 @@ def _init_progress(self, trainer, pl_module): # progress has started self._progress_stopped = False + def refresh(self): + if self.progress: + self.progress.refresh() + def on_train_start(self, trainer, pl_module): super().on_train_start(trainer, pl_module) - self._init_progress(trainer, pl_module) + self._init_progress(trainer) def on_predict_start(self, trainer, pl_module): super().on_predict_start(trainer, pl_module) - self._init_progress(trainer, pl_module) + self._init_progress(trainer) def on_test_start(self, trainer, pl_module): super().on_test_start(trainer, pl_module) - self._init_progress(trainer, pl_module) + self._init_progress(trainer) def on_validation_start(self, trainer, pl_module): super().on_validation_start(trainer, pl_module) - self._init_progress(trainer, pl_module) + self._init_progress(trainer) def __getstate__(self): # can't pickle the rich progress objects state = self.__dict__.copy() - state["progress"] = None state["_console"] = None + state["progress"] = None return state def __setstate__(self, state): self.__dict__ = state - # reset console reference after loading progress - self._console = Console() + state["_console"] = Console() def on_sanity_check_start(self, trainer, pl_module): super().on_sanity_check_start(trainer, pl_module) - self._init_progress(trainer, pl_module) + self._init_progress(trainer) self.val_sanity_progress_bar_id = self._add_task(trainer.num_sanity_val_steps, self.sanity_check_description) + self.refresh() def on_sanity_check_end(self, trainer, pl_module): super().on_sanity_check_end(trainer, pl_module) self._update(self.val_sanity_progress_bar_id, visible=False) + self.refresh() def on_train_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) @@ -334,7 +340,9 @@ def on_train_epoch_start(self, trainer, pl_module): train_description = self._get_train_description(trainer.current_epoch) if self.main_progress_bar_id is None: self.main_progress_bar_id = self._add_task(total_batches, train_description) - self.progress.reset(self.main_progress_bar_id, total=total_batches, description=train_description) + if self.progress is not None: + self.progress.reset(self.main_progress_bar_id, total=total_batches, description=train_description) + self.refresh() def on_validation_epoch_start(self, trainer, pl_module): super().on_validation_epoch_start(trainer, pl_module) @@ -345,6 +353,7 @@ def on_validation_epoch_start(self, trainer, pl_module): val_checks_per_epoch = self.total_train_batches // trainer.val_check_batch total_val_batches = self.total_val_batches * val_checks_per_epoch self.val_progress_bar_id = self._add_task(total_val_batches, self.validation_description, visible=False) + self.refresh() def _add_task(self, total_batches: int, description: str, visible: bool = True) -> Optional[int]: if self.progress is not None: @@ -352,44 +361,59 @@ def _add_task(self, total_batches: int, description: str, visible: bool = True) f"[{self.theme.text_color}]{description}", total=total_batches, visible=visible ) - def _update(self, progress_bar_id: int, visible: bool = True) -> None: - if self.progress is not None: + def _update(self, progress_bar_id: int, current: int, total: int, visible: bool = True) -> None: + if self.progress is not None and self._should_update(current, total): self.progress.update(progress_bar_id, advance=1.0, visible=visible) + self.refresh() + + def _should_update(self, current, total) -> bool: + return self.is_enabled and (current % self.refresh_rate == 0 or current == total) def on_validation_epoch_end(self, trainer, pl_module): super().on_validation_epoch_end(trainer, pl_module) if self.val_progress_bar_id is not None: - self._update(self.val_progress_bar_id, visible=False) + self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches, visible=False) + + def on_validation_end(self, trainer, pl_module) -> None: + super().on_validation_end(trainer, pl_module) + self._update_metrics(trainer, pl_module) def on_test_epoch_start(self, trainer, pl_module): super().on_train_epoch_start(trainer, pl_module) self.test_progress_bar_id = self._add_task(self.total_test_batches, self.test_description) + self.refresh() def on_predict_epoch_start(self, trainer, pl_module): super().on_predict_epoch_start(trainer, pl_module) self.predict_progress_bar_id = self._add_task(self.total_predict_batches, self.predict_description) + self.refresh() def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) - self._update(self.main_progress_bar_id) + self._update(self.main_progress_bar_id, self.train_batch_idx, self.total_train_batches) + self._update_metrics(trainer, pl_module) + self.refresh() def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) if trainer.sanity_checking: - self._update(self.val_sanity_progress_bar_id) + self._update(self.val_sanity_progress_bar_id, self.val_batch_idx, self.total_val_batches) elif self.val_progress_bar_id is not None: # check to see if we should update the main training progress bar if self.main_progress_bar_id is not None: - self._update(self.main_progress_bar_id) - self._update(self.val_progress_bar_id) + self._update(self.main_progress_bar_id, self.val_batch_idx, self.total_val_batches) + self._update(self.val_progress_bar_id, self.val_batch_idx, self.total_val_batches) + self.refresh() def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self._update(self.test_progress_bar_id) + self._update(self.test_progress_bar_id, self.test_batch_idx, self.total_test_batches) + self.refresh() def on_predict_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): super().on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx) - self._update(self.predict_progress_bar_id) + self._update(self.predict_progress_bar_id, self.predict_batch_idx, self.total_predict_batches) + self.refresh() def _get_train_description(self, current_epoch: int) -> str: train_description = f"Epoch {current_epoch}" @@ -414,6 +438,11 @@ def _reset_progress_bar_ids(self): self.test_progress_bar_id: Optional[int] = None self.predict_progress_bar_id: Optional[int] = None + def _update_metrics(self, trainer, pl_module) -> None: + metrics = self.get_metrics(trainer, pl_module) + if self._metric_component: + self._metric_component.update(metrics) + def teardown(self, trainer, pl_module, stage: Optional[str] = None) -> None: self._stop_progress() diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 4d41734ed90e6..aac51b14ff9c1 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -26,9 +26,9 @@ ) from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.timer import Timer -from pytorch_lightning.utilities import ModelSummaryMode, rank_zero_info +from pytorch_lightning.utilities import _RICH_AVAILABLE, ModelSummaryMode, rank_zero_info from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.warnings import rank_zero_deprecation +from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn class CallbackConnector: @@ -216,11 +216,7 @@ def _configure_swa_callbacks(self): self.trainer.callbacks = [StochasticWeightAveraging()] + self.trainer.callbacks def configure_progress_bar(self, refresh_rate=None, process_position=0): - if os.getenv("COLAB_GPU") and refresh_rate is None: - # smaller refresh rate on colab causes crashes, choose a higher value - refresh_rate = 20 - refresh_rate = 1 if refresh_rate is None else refresh_rate - + # if progress bar callback already exists return it progress_bars = [c for c in self.trainer.callbacks if isinstance(c, ProgressBarBase)] if len(progress_bars) > 1: raise MisconfigurationException( @@ -228,14 +224,30 @@ def configure_progress_bar(self, refresh_rate=None, process_position=0): " progress bar is supported." ) if len(progress_bars) == 1: - progress_bar_callback = progress_bars[0] - elif refresh_rate > 0: + return progress_bars[0] + # check if progress bar has been turned off (i.e refresh_rate == 0) + if refresh_rate == 0: + return + # if Rich is available and refresh_rate is None return Rich ProgressBar + if _RICH_AVAILABLE: + if refresh_rate is None: + progress_bar_callback = RichProgressBar() + self.trainer.callbacks.append(progress_bar_callback) + return progress_bar_callback + rank_zero_warn( + "`RichProgressBar` does not support setting the refresh rate via the Trainer." + " If you'd like to change the refresh rate and continue using the `RichProgressBar`," + " please pass `callbacks=RichProgressBar(refresh_rate=X)`." + " Setting to the `TQDM ProgressBar`." + ) + # else return new TQDMProgressBar + if os.getenv("COLAB_GPU") and refresh_rate is None: + # smaller refresh rate on colab causes crashes for TQDM, choose a higher value + refresh_rate = 20 + refresh_rate = 1 if refresh_rate is None else refresh_rate progress_bar_callback = TQDMProgressBar(refresh_rate=refresh_rate, process_position=process_position) self.trainer.callbacks.append(progress_bar_callback) - else: - progress_bar_callback = None - - return progress_bar_callback + return progress_bar_callback def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None) -> None: if max_time is None: diff --git a/test.py b/test.py new file mode 100644 index 0000000000000..e2c6b18286f81 --- /dev/null +++ b/test.py @@ -0,0 +1,58 @@ +import os + +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_lightning import LightningModule, Trainer + + +class RandomDataset(Dataset): + def __init__(self, size, length): + self.len = length + self.data = torch.randn(length, size) + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len + + +class BoringModel(LightningModule): + def __init__(self): + super().__init__() + self.layer = torch.nn.Linear(32, 2) + + def forward(self, x): + return self.layer(x) + + def training_step(self, batch, batch_idx): + loss = self(batch).sum() + self.log("train_loss", loss) + return {"loss": loss} + + def on_after_backward(self) -> None: + print("Engine", {name: p.grad for name, p in self.trainer.model.named_parameters()}) + print("LM", {name: p.grad for name, p in self.named_parameters()}) + + def configure_optimizers(self): + return torch.optim.SGD(self.layer.parameters(), lr=0.1) + + +def run(): + train_data = DataLoader(RandomDataset(32, 64), batch_size=2) + + model = BoringModel() + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=1, + max_epochs=1, + enable_model_summary=False, + strategy="deepspeed", + gpus=1, + ) + trainer.fit(model, train_dataloaders=train_data) + + +if __name__ == "__main__": + run() diff --git a/tests/callbacks/test_rich_progress_bar.py b/tests/callbacks/test_rich_progress_bar.py index 708cbafa4d65a..6a41494ccc1c8 100644 --- a/tests/callbacks/test_rich_progress_bar.py +++ b/tests/callbacks/test_rich_progress_bar.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pickle from unittest import mock from unittest.mock import DEFAULT @@ -18,7 +19,7 @@ from torch.utils.data import DataLoader from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ProgressBarBase, RichProgressBar +from pytorch_lightning.callbacks import ProgressBar, ProgressBarBase, RichProgressBar from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme from pytorch_lightning.utilities.imports import _RICH_AVAILABLE from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset @@ -36,7 +37,7 @@ def test_rich_progress_bar_callback(): @RunIf(rich=True) -def test_rich_progress_bar_refresh_rate(): +def test_rich_progress_bar_refresh_rate_enable_disable(): progress_bar = RichProgressBar(refresh_rate_per_second=1) assert progress_bar.is_enabled assert not progress_bar.is_disabled @@ -45,6 +46,23 @@ def test_rich_progress_bar_refresh_rate(): assert progress_bar.is_disabled +@RunIf(rich=True) +def test_rich_progress_bar_refresh_rate(tmpdir): + """Test that the refresh rate is set correctly based on the Trainer, and warn if the user sets the argument.""" + trainer = Trainer(default_root_dir=tmpdir) + assert trainer.progress_bar_callback.refresh_rate_per_second == 10 + assert isinstance(trainer.progress_bar_callback, RichProgressBar) + + trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=None) + assert isinstance(trainer.progress_bar_callback, RichProgressBar) + assert trainer.progress_bar_callback.refresh_rate_per_second == 10 + + with pytest.warns(UserWarning, match="does not support setting the refresh rate via the Trainer."): + trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=19) + assert isinstance(trainer.progress_bar_callback, ProgressBar) + assert trainer.progress_bar_callback.refresh_rate == 19 + + @RunIf(rich=True) @mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress.update") @pytest.mark.parametrize("dataset", [RandomDataset(32, 64), RandomIterableDataset(32, 64)]) @@ -90,7 +108,8 @@ def test_rich_progress_bar_import_error(): @RunIf(rich=True) -def test_rich_progress_bar_custom_theme(tmpdir): +@mock.patch("pytorch_lightning.callbacks.progress.rich_progress.Progress") +def test_rich_progress_bar_custom_theme(mock_progress, tmpdir): """Test to ensure that custom theme styles are used.""" with mock.patch.multiple( "pytorch_lightning.callbacks.progress.rich_progress", @@ -141,3 +160,132 @@ def on_train_start(self) -> None: trainer.fit(model) mock_progress_stop.assert_called_once() + trainer.progress_bar_callback.teardown(trainer, model) + + +@RunIf(rich=True) +def test_progress_bar_totals(tmpdir): + """Test that the progress finishes with the correct total steps processed.""" + + model = BoringModel() + + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + bar = trainer.progress_bar_callback + assert float("inf") == bar.total_train_batches + assert 0 == bar.total_val_batches + assert 0 == bar.total_test_batches + + trainer.fit(model) + + # check main progress bar total + n = bar.total_train_batches + m = bar.total_val_batches + + assert len(trainer.train_dataloader) == n + assert bar.main_progress_bar.total == n + m + + # check val progress bar total + assert sum(len(loader) for loader in trainer.val_dataloaders) == m + assert bar.val_progress_bar.total == m + + # check that the test progress bar is off + assert bar.total_test_batches == 0 + + trainer.validate(model) + + assert bar.val_progress_bar.total == m + assert bar.val_batch_idx == m + + trainer.test(model) + + # check test progress bar total + k = bar.total_test_batches + assert sum(len(loader) for loader in trainer.test_dataloaders) == k + assert bar.test_progress_bar.total == k + assert bar.test_batch_idx == k + + +@RunIf(rich=True) +def test_progress_bar_fast_dev_run(tmpdir): + model = BoringModel() + + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + + trainer.fit(model) + + bar = trainer.progress_bar_callback + + assert 1 == bar.total_train_batches + # total val batches are known only after val dataloaders have reloaded + + assert 1 == bar.total_val_batches + assert 1 == bar.train_batch_idx + assert 1 == bar.val_batch_idx + assert 0 == bar.test_batch_idx + + # the main progress bar should display 2 batches (1 train, 1 val) + assert 2 == bar.main_progress_bar.total + + trainer.validate(model) + + # the validation progress bar should display 1 batch + assert 1 == bar.val_batch_idx + assert 1 == bar.val_progress_bar.total + + trainer.test(model) + + # the test progress bar should display 1 batch + assert 1 == bar.test_batch_idx + assert 1 == bar.test_progress_bar.total + + +@RunIf(rich=True) +@pytest.mark.parametrize("limit_val_batches", (0, 5)) +def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int): + """Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument.""" + + class CurrentProgressBar(RichProgressBar): + val_pbar_total = 0 + sanity_pbar_total = 0 + + def on_sanity_check_end(self, *args): + super().on_sanity_check_end(*args) + self.sanity_pbar_total = self.val_sanity_check_bar.total + + def on_validation_epoch_end(self, *args): + super().on_validation_epoch_end(*args) + self.val_pbar_total = self.val_progress_bar.total + + model = BoringModel() + progress_bar = CurrentProgressBar() + num_sanity_val_steps = 2 + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + num_sanity_val_steps=num_sanity_val_steps, + limit_train_batches=1, + limit_val_batches=limit_val_batches, + callbacks=[progress_bar], + logger=False, + checkpoint_callback=False, + ) + trainer.fit(model) + + assert progress_bar.sanity_pbar_total == min(num_sanity_val_steps, limit_val_batches) + assert progress_bar.val_pbar_total == limit_val_batches + + +@RunIf(rich=True) +def test_progress_bar_can_be_pickled(): + bar = RichProgressBar() + trainer = Trainer(fast_dev_run=True, callbacks=[bar], max_steps=1) + model = BoringModel() + + pickle.dumps(bar) + trainer.fit(model) + pickle.dumps(bar) + trainer.test(model) + pickle.dumps(bar) + trainer.predict(model) + pickle.dumps(bar) diff --git a/tests/callbacks/test_tqdm_progress_bar.py b/tests/callbacks/test_tqdm_progress_bar.py index b92fb18d54ccd..d4edad01149ea 100644 --- a/tests/callbacks/test_tqdm_progress_bar.py +++ b/tests/callbacks/test_tqdm_progress_bar.py @@ -23,9 +23,10 @@ from torch.utils.data.dataloader import DataLoader from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBarBase, TQDMProgressBar +from pytorch_lightning.callbacks import ModelCheckpoint, ProgressBarBase, RichProgressBar, TQDMProgressBar from pytorch_lightning.callbacks.progress.tqdm_progress import Tqdm from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.utilities import _RICH_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -73,7 +74,7 @@ def test_tqdm_progress_bar_off(tmpdir, callbacks: list, refresh_rate: Union[bool enable_progress_bar=enable_progress_bar, ) - progress_bars = [c for c in trainer.callbacks if isinstance(c, TQDMProgressBar)] + progress_bars = [c for c in trainer.callbacks if isinstance(c, ProgressBarBase)] assert 0 == len(progress_bars) assert not trainer.progress_bar_callback @@ -90,7 +91,7 @@ def test_tqdm_progress_bar_totals(tmpdir): model = BoringModel() - trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=1, max_epochs=1) + trainer = Trainer(default_root_dir=tmpdir, progress_bar_refresh_rate=1, max_epochs=1, callbacks=TQDMProgressBar()) bar = trainer.progress_bar_callback assert float("inf") == bar.total_train_batches assert 0 == bar.total_val_batches @@ -146,6 +147,12 @@ def test_tqdm_progress_bar_fast_dev_run(tmpdir): trainer.fit(model) progress_bar = trainer.progress_bar_callback + + train_bar, val_bar = ( + progress_bar.progress.tasks[progress_bar.main_progress_bar_id], + progress_bar.progress.tasks[progress_bar.val_progress_bar_id], + ) + assert 1 == progress_bar.total_train_batches # total val batches are known only after val dataloaders have reloaded @@ -155,22 +162,21 @@ def test_tqdm_progress_bar_fast_dev_run(tmpdir): assert 0 == progress_bar.test_batch_idx # the main progress bar should display 2 batches (1 train, 1 val) - assert 2 == progress_bar.main_progress_bar.total - assert 2 == progress_bar.main_progress_bar.n + assert 2 == train_bar.total trainer.validate(model) # the validation progress bar should display 1 batch assert 1 == progress_bar.val_batch_idx - assert 1 == progress_bar.val_progress_bar.total - assert 1 == progress_bar.val_progress_bar.n + assert 1 == val_bar.total trainer.test(model) + test_bar = progress_bar.progress.tasks[progress_bar.test_progress_bar_id] + # the test progress bar should display 1 batch assert 1 == progress_bar.test_batch_idx - assert 1 == progress_bar.test_progress_bar.total - assert 1 == progress_bar.test_progress_bar.n + assert 1 == test_bar.total @pytest.mark.parametrize("refresh_rate", [0, 1, 50]) @@ -179,7 +185,7 @@ def test_tqdm_progress_bar_progress_refresh(tmpdir, refresh_rate: int): model = BoringModel() - class CurrentProgressBar(TQDMProgressBar): + class CurrentTQDMProgressBar(TQDMProgressBar): train_batches_seen = 0 val_batches_seen = 0 @@ -208,7 +214,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal assert self.test_progress_bar.n == self.test_batch_idx self.test_batches_seen += 1 - progress_bar = CurrentProgressBar(refresh_rate=refresh_rate) + progress_bar = CurrentTQDMProgressBar(refresh_rate=refresh_rate) trainer = Trainer( default_root_dir=tmpdir, callbacks=[progress_bar], @@ -236,7 +242,7 @@ def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, datal @pytest.mark.parametrize("limit_val_batches", (0, 5)) -def test_num_sanity_val_steps_progress_bar(tmpdir, limit_val_batches: int): +def test_num_sanity_val_steps_tqdm_progress_bar(tmpdir, limit_val_batches: int): """Test val_progress_bar total with 'num_sanity_val_steps' Trainer argument.""" class CurrentProgressBar(TQDMProgressBar): @@ -271,6 +277,7 @@ def on_validation_epoch_end(self, *args): assert progress_bar.val_pbar_total == limit_val_batches +@pytest.mark.skipif(_RICH_AVAILABLE, reason="Test requires TQDM progress bar as default.") def test_tqdm_progress_bar_default_value(tmpdir): """Test that a value of None defaults to refresh rate 1.""" trainer = Trainer(default_root_dir=tmpdir) @@ -280,6 +287,7 @@ def test_tqdm_progress_bar_default_value(tmpdir): assert trainer.progress_bar_callback.refresh_rate == 1 +@pytest.mark.skipif(_RICH_AVAILABLE, reason="Test requires TQDM progress bar as default.") @mock.patch.dict(os.environ, {"COLAB_GPU": "1"}) def test_tqdm_progress_bar_value_on_colab(tmpdir): """Test that Trainer will override the default in Google COLAB.""" @@ -368,7 +376,8 @@ def test_test_progress_bar_update_amount(tmpdir, test_batches: int, refresh_rate progress_bar.test_progress_bar.update.assert_has_calls([call(delta) for delta in test_deltas]) -def test_tensor_to_float_conversion(tmpdir): +@pytest.mark.parametrize("progress_bar_cls", (TQDMProgressBar, pytest.param(RichProgressBar, marks=RunIf(rich=True)))) +def test_tensor_to_float_conversion(progress_bar_cls, tmpdir): """Check tensor gets converted to float.""" class TestModel(BoringModel): @@ -379,7 +388,12 @@ def training_step(self, batch, batch_idx): return super().training_step(batch, batch_idx) trainer = Trainer( - default_root_dir=tmpdir, max_epochs=1, limit_train_batches=2, logger=False, enable_checkpointing=False + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=2, + logger=False, + enable_checkpointing=False, + callbacks=progress_bar_cls(), ) trainer.fit(TestModel()) @@ -387,8 +401,9 @@ def training_step(self, batch, batch_idx): assert trainer.progress_bar_metrics["b"] == {"b1": 1.0} assert trainer.progress_bar_metrics["c"] == {"c1": 2.0} pbar = trainer.progress_bar_callback.main_progress_bar - actual = str(pbar.postfix) - assert actual.endswith("a=0.123, b={'b1': 1.0}, c={'c1': 2.0}"), actual + if progress_bar_cls != RichProgressBar: + actual = str(pbar.postfix) + assert actual.endswith("a=0.123, b={'b1': 1.0}, c={'c1': 2.0}"), actual @pytest.mark.parametrize( @@ -507,8 +522,9 @@ def test_tqdm_progress_bar_print_disabled(tqdm_write, mock_print, tmpdir): tqdm_write.assert_not_called() -def test_tqdm_progress_bar_can_be_pickled(): - bar = TQDMProgressBar() +@pytest.mark.parametrize("progress_bar_cls", (TQDMProgressBar, pytest.param(RichProgressBar, marks=RunIf(rich=True)))) +def test_progress_bar_can_be_pickled(progress_bar_cls): + bar = progress_bar_cls() trainer = Trainer(fast_dev_run=True, callbacks=[bar], max_steps=1) model = BoringModel() @@ -566,8 +582,9 @@ def _test_progress_bar_max_val_check_interval( assert trainer.progress_bar_callback.main_progress_bar.total == total_train_batches + total_val_batches -def test_get_progress_bar_metrics(tmpdir: str): - class TestProgressBar(TQDMProgressBar): +@pytest.mark.parametrize("cls", (TQDMProgressBar, pytest.param(RichProgressBar, marks=RunIf(rich=True)))) +def test_get_progress_bar_metrics(cls, tmpdir: str): + class TestProgressBar(cls): def get_metrics(self, trainer: Trainer, model: LightningModule): items = super().get_metrics(trainer, model) items.pop("v_num", None) @@ -588,9 +605,10 @@ def get_metrics(self, trainer: Trainer, model: LightningModule): assert "v_num" not in standard_metrics.keys() -def test_tqdm_progress_bar_main_bar_resume(): +@pytest.mark.parametrize("cls", (TQDMProgressBar, pytest.param(RichProgressBar, marks=RunIf(rich=True)))) +def test_tqdm_progress_bar_main_bar_resume(cls): """Test that the progress bar can resume its counters based on the Trainer state.""" - bar = TQDMProgressBar() + bar = cls() trainer = Mock() model = Mock() @@ -606,7 +624,8 @@ def test_tqdm_progress_bar_main_bar_resume(): bar.on_train_start(trainer, model) bar.on_train_epoch_start(trainer, model) - assert bar.main_progress_bar.n == 3 + if not isinstance(bar, RichProgressBar): + assert bar.main_progress_bar.n == 3 assert bar.main_progress_bar.total == 8 # bar.on_train_epoch_end(trainer, model) @@ -614,5 +633,6 @@ def test_tqdm_progress_bar_main_bar_resume(): bar.on_validation_epoch_start(trainer, model) # restarting mid validation epoch is not currently supported - assert bar.val_progress_bar.n == 0 + if not isinstance(bar, RichProgressBar): + assert bar.val_progress_bar.n == 0 assert bar.val_progress_bar.total == 3 diff --git a/tests/deprecated_api/test_remove_1-7.py b/tests/deprecated_api/test_remove_1-7.py index 16c511b6effd9..92f092cd71f4a 100644 --- a/tests/deprecated_api/test_remove_1-7.py +++ b/tests/deprecated_api/test_remove_1-7.py @@ -107,7 +107,7 @@ def get_progress_bar_dict(self): test_model = TestModel() with pytest.deprecated_call(match=r"`LightningModule.get_progress_bar_dict` method was deprecated in v1.5"): trainer.fit(test_model) - standard_metrics_postfix = trainer.progress_bar_callback.main_progress_bar.postfix + standard_metrics_postfix = trainer.progress_bar_callback.get_metrics(trainer, test_model) assert "loss" in standard_metrics_postfix assert "v_num" not in standard_metrics_postfix