Commit f24349b

Logger support in Lite (#16121)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 2166ce2 commit f24349b

File tree: 15 files changed, +884 -258 lines


docs/source-pytorch/fabric/fabric.rst

Lines changed: 50 additions & 0 deletions
@@ -385,6 +385,30 @@ Then, in your training loop, you can call a hook by its name. Any callback objec
     fabric.call("on_train_epoch_end", results={...})


+loggers
+=======
+
+Attach one or several loggers/experiment trackers to Fabric for convenient logging of metrics.
+
+.. code-block:: python
+
+    # Default used by Fabric, no loggers are active
+    fabric = Fabric(loggers=[])
+
+    # Log to a single logger
+    fabric = Fabric(loggers=TensorBoardLogger(...))
+
+    # Or multiple instances
+    fabric = Fabric(loggers=[logger1, logger2, ...])
+
+Anywhere in your training loop, you can log metrics to all loggers at once:
+
+.. code-block:: python
+
+    fabric.log("loss", loss)
+    fabric.log_dict({"loss": loss, "accuracy": acc})
+
+
 ----------

@@ -613,3 +637,29 @@ It is useful when building a Trainer that allows the user to run arbitrary code

     # Only the callbacks that have this method defined will be executed
     fabric.call("undefined")
+
+
+log and log_dict
+================
+
+These methods allow you to send scalar metrics to a logger registered in Fabric.
+
+.. code-block:: python
+
+    # Set the logger in Fabric
+    fabric = Fabric(loggers=TensorBoardLogger(...))
+
+    # Anywhere in your training loop or model:
+    fabric.log("loss", loss)
+
+    # Or send multiple metrics at once:
+    fabric.log_dict({"loss": loss, "accuracy": acc})
+
+If no loggers are given to Fabric (default), ``log`` and ``log_dict`` won't do anything.
+Here is what's happening under the hood (pseudocode) when you call ``.log()`` or ``.log_dict()``:
+
+.. code-block:: python
+
+    # When you call .log() or .log_dict(), we do this:
+    for logger in fabric.loggers:
+        logger.log_metrics(metrics=metrics, step=step)
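
Taken together, the documented API supports a loop like the following minimal sketch. The ``root_dir`` argument, the toy model, and the free-standing loop (some Fabric versions expect the code inside ``run`` or after ``fabric.launch()``) are illustrative assumptions, not part of this commit:

    import torch
    from lightning_fabric import Fabric
    from lightning_fabric.loggers import TensorBoardLogger

    # Assumption: the TensorBoard logger accepts a root directory for its event files.
    fabric = Fabric(loggers=TensorBoardLogger(root_dir="logs"))

    model = torch.nn.Linear(32, 2)  # toy model for illustration
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    for step in range(100):
        x, y = torch.randn(8, 32), torch.randint(0, 2, (8,))
        loss = torch.nn.functional.cross_entropy(model(x), y)
        fabric.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        fabric.log("train/loss", loss.item(), step=step)  # goes to every attached logger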

requirements/fabric/test.txt

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@ pytest==7.2.0
 pytest-cov==4.0.0
 pre-commit==2.20.0
 click==8.1.3
+tensorboard>=2.9.1, <2.12.0

src/lightning/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -40,7 +40,6 @@ def _detail(self: Any, message: str, *args: Any, **kwargs: Any) -> None:
 from lightning.fabric.utilities.seed import seed_everything  # noqa: E402
 from lightning.pytorch.callbacks import Callback  # noqa: E402
 from lightning.pytorch.core import LightningDataModule, LightningModule  # noqa: E402
-from lightning.pytorch.lite import LightningLite  # noqa: E402
 from lightning.pytorch.trainer import Trainer  # noqa: E402

 import lightning.app  # isort: skip  # noqa: E402

@@ -61,7 +60,6 @@ def _detail(self: Any, message: str, *args: Any, **kwargs: Any) -> None:
     "LightningModule",
     "Callback",
     "seed_everything",
-    "LightningLite",
     "Fabric",
     "storage",
     "pdb",

src/lightning_fabric/CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -30,6 +30,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added support for managing callbacks via `Fabric(callbacks=...)` and emitting events through `Fabric.call()` ([#16074](https://github.com/Lightning-AI/lightning/issues/16074))

+- Added Logger support ([#16121](https://github.com/Lightning-AI/lightning/issues/16121))
+  * Added `Fabric(loggers=...)` to support different Logger frameworks in Fabric
+  * Added `Fabric.log` for logging scalars using multiple loggers
+  * Added `Fabric.log_dict` for logging a dictionary of multiple metrics at once
+  * Added `Fabric.loggers` and `Fabric.logger` attributes to access the individual logger instances
+

 - Added support for a consistent `.zero_grad(set_to_none=...)` on the wrapped optimizer regardless of which strategy is used ([#16275](https://github.com/Lightning-AI/lightning/issues/16275))

src/lightning_fabric/fabric.py

Lines changed: 42 additions & 1 deletion
@@ -22,10 +22,13 @@
 import torch.nn as nn
 from lightning_utilities.core.apply_func import apply_to_collection
 from lightning_utilities.core.overrides import is_overridden
+from lightning_utilities.core.rank_zero import rank_zero_warn
 from torch import Tensor
 from torch.optim import Optimizer
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler

+from lightning_fabric.loggers import Logger
+
 from lightning_fabric.plugins import Precision  # avoid circular imports: # isort: split
 from lightning_fabric.accelerators.accelerator import Accelerator
 from lightning_fabric.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT

@@ -47,7 +50,6 @@
     has_iterable_dataset,
 )
 from lightning_fabric.utilities.distributed import DistributedSamplerWrapper
-from lightning_fabric.utilities.rank_zero import rank_zero_warn
 from lightning_fabric.utilities.seed import seed_everything
 from lightning_fabric.utilities.warnings import PossibleUserWarning
 from lightning_fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer

@@ -74,6 +76,10 @@ class Fabric:
         precision: Double precision (``64``), full precision (``32``), half precision (``16``),
             or bfloat16 precision (``"bf16"``).
         plugins: One or several custom plugins
+        callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods that
+            can be invoked through :meth:`lightning_fabric.fabric.Fabric.call` by the user.
+        loggers: A single logger or a list of loggers. See :meth:`lightning_fabric.fabric.Fabric.log` for more
+            information.
     """

     def __init__(

@@ -85,6 +91,7 @@ def __init__(
         precision: _PRECISION_INPUT = 32,
         plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
         callbacks: Optional[Union[List[Any], Any]] = None,
+        loggers: Optional[Union[Logger, List[Logger]]] = None,
     ) -> None:
         self._connector = _Connector(
             accelerator=accelerator,

@@ -99,6 +106,8 @@ def __init__(
         self._precision: Precision = self._strategy.precision
         callbacks = callbacks if callbacks is not None else []
         self._callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
+        loggers = loggers if loggers is not None else []
+        self._loggers = loggers if isinstance(loggers, list) else [loggers]
         self._models_setup: int = 0

         self._prepare_run_method()

@@ -148,6 +157,16 @@ def is_global_zero(self) -> bool:
         """Whether this rank is rank zero."""
         return self._strategy.is_global_zero

+    @property
+    def loggers(self) -> List[Logger]:
+        """Returns all loggers passed to Fabric."""
+        return self._loggers
+
+    @property
+    def logger(self) -> Logger:
+        """Returns the first logger in the list passed to Fabric, which is considered the main logger."""
+        return self._loggers[0]
+
     def run(self, *args: Any, **kwargs: Any) -> Any:
         """All the code inside this run method gets accelerated by Fabric.

@@ -573,6 +592,28 @@ def on_train_epoch_end(self, results):
         # method(self, *args, y=1)
         # method(self, *args, **kwargs)

+    def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
+        """Log a scalar to all loggers that were added to Fabric.
+
+        Args:
+            name: The name of the metric to log.
+            value: The metric value to collect.
+            step: Optional step number. Most Logger implementations auto-increment the step value by one with every
+                log call. You can specify your own value here.
+        """
+        self.log_dict(metrics={name: value}, step=step)
+
+    def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
+        """Log multiple scalars at once to all loggers that were added to Fabric.
+
+        Args:
+            metrics: A dictionary where the key is the name of the metric and the value the scalar to be logged.
+            step: Optional step number. Most Logger implementations auto-increment this value by one with every
+                log call. You can specify your own value here.
+        """
+        for logger in self._loggers:
+            logger.log_metrics(metrics=metrics, step=step)
+
     @staticmethod
     def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int:
         """Helper function to seed everything without explicitly importing Lightning.
src/lightning_fabric/loggers/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# Copyright The PyTorch Lightning team.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from lightning_fabric.loggers.logger import Logger  # noqa: F401
+from lightning_fabric.loggers.tensorboard import TensorBoardLogger  # noqa: F401
src/lightning_fabric/loggers/logger.py

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Abstract base class used to build new loggers."""
+
+from abc import ABC, abstractmethod
+from argparse import Namespace
+from functools import wraps
+from typing import Any, Callable, Dict, Optional, Union
+
+from torch import Tensor
+from torch.nn import Module
+
+from lightning_fabric.utilities.rank_zero import rank_zero_only
+
+
+class Logger(ABC):
+    """Base class for experiment loggers."""
+
+    @property
+    @abstractmethod
+    def name(self) -> Optional[str]:
+        """Return the experiment name."""
+
+    @property
+    @abstractmethod
+    def version(self) -> Optional[Union[int, str]]:
+        """Return the experiment version."""
+
+    @property
+    def root_dir(self) -> Optional[str]:
+        """Return the root directory where all versions of an experiment get saved, or `None` if the logger does
+        not save data locally."""
+        return None
+
+    @property
+    def log_dir(self) -> Optional[str]:
+        """Return the directory where the current version of the experiment gets saved, or `None` if the logger
+        does not save data locally."""
+        return None
+
+    @property
+    def group_separator(self) -> str:
+        """Return the default separator used by the logger to group the data into subfolders."""
+        return "/"
+
+    @abstractmethod
+    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+        """Records metrics. This method logs metrics as soon as it receives them.
+
+        Args:
+            metrics: Dictionary with metric names as keys and measured quantities as values
+            step: Step number at which the metrics should be recorded
+        """
+        pass
+
+    @abstractmethod
+    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None:
+        """Record hyperparameters.
+
+        Args:
+            params: :class:`~argparse.Namespace` or `Dict` containing the hyperparameters
+            args: Optional positional arguments, depends on the specific logger being used
+            kwargs: Optional keyword arguments, depends on the specific logger being used
+        """
+
+    def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
+        """Record model graph.
+
+        Args:
+            model: the model with an implementation of ``forward``.
+            input_array: input passed to ``model.forward``
+        """
+        pass
+
+    def save(self) -> None:
+        """Save log data."""
+
+    def finalize(self, status: str) -> None:
+        """Do any processing that is necessary to finalize an experiment.
+
+        Args:
+            status: Status that the experiment finished with (e.g. success, failed, aborted)
+        """
+        self.save()
+
+
+def rank_zero_experiment(fn: Callable) -> Callable:
+    """Returns the real experiment on rank 0 and otherwise the _DummyExperiment."""
+
+    @wraps(fn)
+    def experiment(self) -> Union[Any, _DummyExperiment]:  # type: ignore[no-untyped-def]
+        """
+        Note:
+            ``self`` is a custom logger instance. The loggers typically wrap an ``experiment`` method
+            with a ``@rank_zero_experiment`` decorator.
+
+            ``Union[Any, _DummyExperiment]`` is used because the wrapped hooks have several return
+            types that are specific to the custom logger. The return type here can be considered as
+            ``Union[return type of logger.experiment, _DummyExperiment]``.
+        """
+
+        @rank_zero_only
+        def get_experiment() -> Callable:
+            return fn(self)
+
+        return get_experiment() or _DummyExperiment()
+
+    return experiment
+
+
+class _DummyExperiment:
+    """Dummy experiment."""
+
+    def nop(self, *args: Any, **kw: Any) -> None:
+        pass
+
+    def __getattr__(self, _: Any) -> Callable:
+        return self.nop
+
+    def __getitem__(self, idx: int) -> "_DummyExperiment":
+        # enables self.logger.experiment[0].add_image(...)
+        return self
+
+    def __setitem__(self, *args: Any, **kwargs: Any) -> None:
+        pass
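
Only ``name``, ``version``, ``log_metrics``, and ``log_hyperparams`` are abstract, so a new logger needs just those four members. A minimal sketch of a custom subclass (the ``InMemoryLogger`` class is hypothetical and not part of this commit):

    from argparse import Namespace
    from typing import Any, Dict, List, Optional, Tuple, Union

    from lightning_fabric.loggers import Logger


    class InMemoryLogger(Logger):
        """Hypothetical logger that records everything in memory, e.g. for tests."""

        def __init__(self) -> None:
            self.metrics: List[Tuple[Optional[int], Dict[str, float]]] = []
            self.hparams: Dict[str, Any] = {}

        @property
        def name(self) -> str:
            return "in-memory"

        @property
        def version(self) -> int:
            return 0

        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
            # record the (step, metrics) pair for later inspection
            self.metrics.append((step, metrics))

        def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None:
            self.hparams.update(vars(params) if isinstance(params, Namespace) else params)

Passing it as ``Fabric(loggers=InMemoryLogger())`` makes every ``fabric.log`` call append to the logger's ``metrics`` list.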
