13 changes: 12 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -8,16 +8,27 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-


### Changed

- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path to where the MLflow artifacts for the model checkpoints are stored.
- Change `wandb` default x-axis to `tensorboard`'s `global_step` when `sync_tensorboard=True` ([#20611](https://github.com/Lightning-AI/pytorch-lightning/pull/20611))


- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path to where the MLflow artifacts for the model checkpoints are stored ([#20538](https://github.com/Lightning-AI/pytorch-lightning/pull/20538))


### Removed

-


### Fixed

- Fix CSVLogger logging hyperparameters at every write, which increases latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594))


- Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610))


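For context, a minimal usage sketch of the behavior described in the `wandb` changelog entry above (the project name and `Trainer` arguments are illustrative, not part of this PR):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger

# With sync_tensorboard=True, wandb mirrors TensorBoard event files, and this
# change makes tensorboard's own `global_step` the default x-axis for metrics.
logger = WandbLogger(project="demo", sync_tensorboard=True)

# Without sync_tensorboard, the previous behavior is unchanged: metrics are
# still plotted against "trainer/global_step".
trainer = Trainer(logger=logger, max_epochs=1)
```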
9 changes: 6 additions & 3 deletions src/lightning/pytorch/loggers/wandb.py
@@ -410,8 +410,11 @@ def experiment(self) -> Union["Run", "RunDisabled"]:
if isinstance(self._experiment, (Run, RunDisabled)) and getattr(
    self._experiment, "define_metric", None
):
    self._experiment.define_metric("trainer/global_step")
    self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
    if self._wandb_init.get("sync_tensorboard"):
        self._experiment.define_metric("*", step_metric="global_step")
    else:
        self._experiment.define_metric("trainer/global_step")
        self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)

return self._experiment

@@ -434,7 +437,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None)
assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
if step is not None:
if step is not None and not self._wandb_init.get("sync_tensorboard"):
    self.experiment.log(dict(metrics, **{"trainer/global_step": step}))
else:
    self.experiment.log(metrics)
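As a plain-`wandb` illustration of what the new branch in `experiment` does (a sketch, not the logger's exact code; the `wandb.init` arguments are placeholders):

```python
import wandb

sync_tensorboard = True  # mirrors self._wandb_init.get("sync_tensorboard") in the logger
run = wandb.init(project="demo", sync_tensorboard=sync_tensorboard)

if sync_tensorboard:
    # Let every metric default to tensorboard's own "global_step" as the x-axis.
    run.define_metric("*", step_metric="global_step")
else:
    # Previous behavior: everything is plotted against "trainer/global_step".
    run.define_metric("trainer/global_step")
    run.define_metric("*", step_metric="trainer/global_step", step_sync=True)

# With sync_tensorboard enabled, log_metrics also stops injecting
# "trainer/global_step" into each wandb.log call (second hunk above), so the
# synced tensorboard step stays the single source of truth for the x-axis.
run.finish()
```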
1 change: 1 addition & 0 deletions tests/tests_pytorch/loggers/conftest.py
@@ -55,6 +55,7 @@ class RunType: # to make isinstance checks pass
    watch=Mock(),
    log_artifact=Mock(),
    use_artifact=Mock(),
    define_metric=Mock(),
    id="run_id",
)

18 changes: 18 additions & 0 deletions tests/tests_pytorch/loggers/test_wandb.py
@@ -126,6 +126,24 @@ def test_wandb_logger_init(wandb_mock):
    assert logger.version == wandb_mock.init().id


def test_wandb_logger_sync_tensorboard(wandb_mock):
    logger = WandbLogger(sync_tensorboard=True)
    wandb_mock.run = None
    logger.experiment

    # test that tensorboard's global_step is set as the default x-axis if sync_tensorboard=True
    wandb_mock.init.return_value.define_metric.assert_called_once_with("*", step_metric="global_step")


def test_wandb_logger_sync_tensorboard_log_metrics(wandb_mock):
    logger = WandbLogger(sync_tensorboard=True)
    metrics = {"loss": 1e-3, "accuracy": 0.99}
    logger.log_metrics(metrics)

    # test that trainer/global_step is not added to the logged metrics if sync_tensorboard=True
    wandb_mock.run.log.assert_called_once_with(metrics)


def test_wandb_logger_init_before_spawn(wandb_mock):
    logger = WandbLogger()
    assert logger._experiment is None