Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Renamed `training_type_plugin` file to `strategy` ([#11239](https://github.com/PyTorchLightning/pytorch-lightning/pull/11239))


- Changed `DeviceStatsMonitor` to group metrics based on the logger's `group_separator` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))


### Deprecated

- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
Expand Down
10 changes: 6 additions & 4 deletions pytorch_lightning/callbacks/device_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ def on_train_batch_start(
return

device_stats = trainer.accelerator.get_device_stats(pl_module.device)
prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_start")
separator = trainer.logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
assert trainer.logger is not None
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

Expand All @@ -75,10 +76,11 @@ def on_train_batch_end(
return

device_stats = trainer.accelerator.get_device_stats(pl_module.device)
prefixed_device_stats = prefix_metrics_keys(device_stats, "on_train_batch_end")
separator = trainer.logger.group_separator
prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
assert trainer.logger is not None
trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)


def prefix_metrics_keys(metrics_dict: Dict[str, float], prefix: str) -> Dict[str, float]:
return {prefix + "." + k: v for k, v in metrics_dict.items()}
def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
return {prefix + separator + k: v for k, v in metrics_dict.items()}
10 changes: 10 additions & 0 deletions tests/callbacks/test_device_stats_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import DeviceStatsMonitor
from pytorch_lightning.callbacks.device_stats_monitor import _prefix_metric_keys
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -128,3 +129,12 @@ def test_device_stats_monitor_no_logger(tmpdir):

with pytest.raises(MisconfigurationException, match="Trainer that has no logger."):
trainer.fit(model)


def test_prefix_metric_keys():
    """Test that metric key names are converted correctly.

    Each key should become ``f"{prefix}{separator}{key}"`` with the value untouched.
    """
    # NOTE: the unused `tmpdir` fixture parameter was removed — this test touches no files.
    metrics = {"1": 1.0, "2": 2.0, "3": 3.0}
    prefix = "foo"
    separator = "."
    converted_metrics = _prefix_metric_keys(metrics, prefix, separator)
    assert converted_metrics == {"foo.1": 1.0, "foo.2": 2.0, "foo.3": 3.0}