diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index eeed4cef5e5cc..62f43d57f5d1f 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -8,16 +8,27 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+-
+
+
 ### Changed
 
-- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path to where the MLflow artifacts for the model checkpoints are stored.
+- Changed `wandb`'s default x-axis to `tensorboard`'s `global_step` when `sync_tensorboard=True` ([#20611](https://github.com/Lightning-AI/pytorch-lightning/pull/20611))
+
+
+- Added a new `checkpoint_path_prefix` parameter to the MLflow logger, which controls where the MLflow artifacts for the model checkpoints are stored ([#20538](https://github.com/Lightning-AI/pytorch-lightning/pull/20538))
+
 
 ### Removed
 
+-
+
+
 ### Fixed
 
 - Fix CSVLogger logging hyperparameter at every write which increase latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))
 
+
 - Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610))
 
 
diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py
index 2429748f73179..0ea32b97c46d1 100644
--- a/src/lightning/pytorch/loggers/wandb.py
+++ b/src/lightning/pytorch/loggers/wandb.py
@@ -410,8 +410,11 @@ def experiment(self) -> Union["Run", "RunDisabled"]:
             if isinstance(self._experiment, (Run, RunDisabled)) and getattr(
                 self._experiment, "define_metric", None
             ):
-                self._experiment.define_metric("trainer/global_step")
-                self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
+                if self._wandb_init.get("sync_tensorboard"):
+                    self._experiment.define_metric("*", step_metric="global_step")
+                else:
+                    self._experiment.define_metric("trainer/global_step")
+                    self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
 
         return self._experiment
 
@@ -434,7 +437,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
         assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
 
         metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
-        if step is not None:
+        if step is not None and not self._wandb_init.get("sync_tensorboard"):
             self.experiment.log(dict(metrics, **{"trainer/global_step": step}))
         else:
             self.experiment.log(metrics)
diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py
index ab1149ca9651a..033275a9fec62 100644
--- a/tests/tests_pytorch/loggers/conftest.py
+++ b/tests/tests_pytorch/loggers/conftest.py
@@ -55,6 +55,7 @@ class RunType:  # to make isinstance checks pass
         watch=Mock(),
         log_artifact=Mock(),
         use_artifact=Mock(),
+        define_metric=Mock(),
         id="run_id",
     )
 
diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py
index 52ad03bd994b4..7b20423380cb1 100644
--- a/tests/tests_pytorch/loggers/test_wandb.py
+++ b/tests/tests_pytorch/loggers/test_wandb.py
@@ -126,6 +126,24 @@ def test_wandb_logger_init(wandb_mock):
     assert logger.version == wandb_mock.init().id
 
 
+def test_wandb_logger_sync_tensorboard(wandb_mock):
+    logger = WandbLogger(sync_tensorboard=True)
+    wandb_mock.run = None
+    logger.experiment
+
+    # tensorboard's global_step is set as the default x-axis when sync_tensorboard=True
+    wandb_mock.init.return_value.define_metric.assert_called_once_with("*", step_metric="global_step")
+
+
+def test_wandb_logger_sync_tensorboard_log_metrics(wandb_mock):
+    logger = WandbLogger(sync_tensorboard=True)
+    metrics = {"loss": 1e-3, "accuracy": 0.99}
+    logger.log_metrics(metrics)
+
+    # trainer/global_step is not added to the logged metrics when sync_tensorboard=True
+    wandb_mock.run.log.assert_called_once_with(metrics)
+
+
 def test_wandb_logger_init_before_spawn(wandb_mock):
     logger = WandbLogger()
     assert logger._experiment is None
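For reviewers who want to try the patch locally, here is a minimal usage sketch of the behavior it enables. Everything in it is illustrative: the project name, save directory, and the omitted `MyModel` module are placeholders, not part of the change.

# A hedged usage sketch; "demo", "tb_logs", and MyModel are placeholders.
import lightning.pytorch as L
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger

# sync_tensorboard=True is forwarded to wandb.init(), which patches TensorBoard
# so events written by TensorBoardLogger are mirrored into the wandb run.
# With this patch, WandbLogger then uses tensorboard's "global_step" as the
# x-axis for all metrics instead of defining "trainer/global_step".
wandb_logger = WandbLogger(project="demo", sync_tensorboard=True)
tb_logger = TensorBoardLogger(save_dir="tb_logs")

trainer = L.Trainer(max_epochs=1, logger=[tb_logger, wandb_logger])
# trainer.fit(MyModel())  # MyModel: any LightningModule (placeholder)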
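For context on the `define_metric` semantics the new branch relies on, a standalone wandb sketch (not Lightning code) follows; the project name and logged values are assumptions for illustration only.

# Plain wandb illustration; "demo" and the logged values are made up.
import wandb

run = wandb.init(project="demo", sync_tensorboard=True)
# Chart every metric ("*") against the logged "global_step" value; this is
# the same call WandbLogger.experiment now issues when sync_tensorboard=True.
run.define_metric("*", step_metric="global_step")
run.log({"loss": 0.1, "global_step": 10})
run.finish()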