50 changes: 50 additions & 0 deletions docs/source-pytorch/fabric/fabric.rst
@@ -385,6 +385,30 @@ Then, in your training loop, you can call a hook by its name.
fabric.call("on_train_epoch_end", results={...})


loggers
=======

Attach one or several loggers/experiment trackers to Fabric for convenient logging of metrics.

.. code-block:: python

# Default used by Fabric, no loggers are active
fabric = Fabric(loggers=[])

# Log to a single logger
fabric = Fabric(loggers=TensorBoardLogger(...))

# Or multiple instances
fabric = Fabric(loggers=[logger1, logger2, ...])

Anywhere in your training loop, you can log metrics to all loggers at once:

.. code-block:: python

fabric.log("loss", loss)
fabric.log_dict({"loss": loss, "accuracy": acc})


----------


@@ -613,3 +637,29 @@ It is useful when building a Trainer that allows the user to run arbitrary code

# Only the callbacks that have this method defined will be executed
fabric.call("undefined")


log and log_dict
================

These methods allow you to send scalar metrics to the loggers registered in Fabric.

.. code-block:: python

# Set the logger in Fabric
fabric = Fabric(loggers=TensorBoardLogger(...))

# Anywhere in your training loop or model:
fabric.log("loss", loss)

# Or send multiple metrics at once:
fabric.log_dict({"loss": loss, "accuracy": acc})

If no loggers are attached to Fabric (the default), ``log`` and ``log_dict`` do nothing.
Here is what happens under the hood (pseudocode) when you call ``.log()`` or ``.log_dict()``:

.. code-block:: python

# When you call .log() or .log_dict(), we do this:
for logger in fabric.loggers:
logger.log_metrics(metrics=metrics, step=step)
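
Because of this, any object implementing the ``Logger`` interface can be attached. A minimal sketch of a custom logger (the ``PrintLogger`` class below is hypothetical, not part of the library):

.. code-block:: python

    from lightning_fabric import Fabric
    from lightning_fabric.loggers import Logger


    class PrintLogger(Logger):
        """Hypothetical logger that prints every metric to stdout."""

        @property
        def name(self):
            return "print"

        @property
        def version(self):
            return "0"

        def log_metrics(self, metrics, step=None):
            print(f"step={step}: {metrics}")

        def log_hyperparams(self, params, *args, **kwargs):
            print(f"hyperparameters: {params}")


    fabric = Fabric(loggers=PrintLogger())
    fabric.log("loss", 0.5)  # prints: step=None: {'loss': 0.5}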
1 change: 1 addition & 0 deletions requirements/fabric/test.txt
@@ -4,3 +4,4 @@ pytest==7.2.0
pytest-cov==4.0.0
pre-commit==2.20.0
click==8.1.3
tensorboard>=2.9.1, <2.12.0
2 changes: 0 additions & 2 deletions src/lightning/__init__.py
@@ -40,7 +40,6 @@ def _detail(self: Any, message: str, *args: Any, **kwargs: Any) -> None:
from lightning.fabric.utilities.seed import seed_everything # noqa: E402
from lightning.pytorch.callbacks import Callback # noqa: E402
from lightning.pytorch.core import LightningDataModule, LightningModule # noqa: E402
from lightning.pytorch.lite import LightningLite # noqa: E402
from lightning.pytorch.trainer import Trainer # noqa: E402

import lightning.app # isort: skip # noqa: E402
@@ -61,7 +60,6 @@ def _detail(self: Any, message: str, *args: Any, **kwargs: Any) -> None:
"LightningModule",
"Callback",
"seed_everything",
"LightningLite",
"Fabric",
"storage",
"pdb",
6 changes: 6 additions & 0 deletions src/lightning_fabric/CHANGELOG.md
@@ -30,6 +30,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for managing callbacks via `Fabric(callbacks=...)` and emitting events through `Fabric.call()` ([#16074](https://github.com/Lightning-AI/lightning/issues/16074))

- Added Logger support ([#16121](https://github.com/Lightning-AI/lightning/issues/16121))
* Added `Fabric(loggers=...)` to support different Logger frameworks in Fabric
* Added `Fabric.log` for logging scalars using multiple loggers
* Added `Fabric.log_dict` for logging a dictionary of multiple metrics at once
* Added `Fabric.loggers` and `Fabric.logger` attributes to access the individual logger instances


- Added support for a consistent `.zero_grad(set_to_none=...)` on the wrapped optimizer regardless of which strategy is used ([#16275](https://github.com/Lightning-AI/lightning/issues/16275))

43 changes: 42 additions & 1 deletion src/lightning_fabric/fabric.py
@@ -22,10 +22,13 @@
import torch.nn as nn
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.overrides import is_overridden
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import Tensor
from torch.optim import Optimizer
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler

from lightning_fabric.loggers import Logger

from lightning_fabric.plugins import Precision # avoid circular imports: # isort: split
from lightning_fabric.accelerators.accelerator import Accelerator
from lightning_fabric.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT
@@ -47,7 +50,6 @@
has_iterable_dataset,
)
from lightning_fabric.utilities.distributed import DistributedSamplerWrapper
from lightning_fabric.utilities.rank_zero import rank_zero_warn
from lightning_fabric.utilities.seed import seed_everything
from lightning_fabric.utilities.warnings import PossibleUserWarning
from lightning_fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer
@@ -74,6 +76,10 @@ class Fabric:
precision: Double precision (``64``), full precision (``32``), half precision (``16``),
or bfloat16 precision (``"bf16"``).
plugins: One or several custom plugins
callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods that
can be invoked through :meth:`lightning_fabric.fabric.Fabric.call` by the user.
loggers: A single logger or a list of loggers. See :meth:`lightning_fabric.fabric.Fabric.log` for more
information.
"""

def __init__(
@@ -85,6 +91,7 @@ def __init__(
precision: _PRECISION_INPUT = 32,
plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
callbacks: Optional[Union[List[Any], Any]] = None,
loggers: Optional[Union[Logger, List[Logger]]] = None,
) -> None:
self._connector = _Connector(
accelerator=accelerator,
@@ -99,6 +106,8 @@
self._precision: Precision = self._strategy.precision
callbacks = callbacks if callbacks is not None else []
self._callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
loggers = loggers if loggers is not None else []
self._loggers = loggers if isinstance(loggers, list) else [loggers]
self._models_setup: int = 0

self._prepare_run_method()
@@ -148,6 +157,16 @@ def is_global_zero(self) -> bool:
"""Whether this rank is rank zero."""
return self._strategy.is_global_zero

@property
def loggers(self) -> List[Logger]:
"""Returns all loggers passed to Fabric."""
return self._loggers

@property
def logger(self) -> Logger:
"""Returns the first logger in the list passed to Fabric, which is considered the main logger."""
return self._loggers[0]

def run(self, *args: Any, **kwargs: Any) -> Any:
"""All the code inside this run method gets accelerated by Fabric.

@@ -573,6 +592,28 @@ def on_train_epoch_end(self, results):
# method(self, *args, y=1)
# method(self, *args, **kwargs)

def log(self, name: str, value: Any, step: Optional[int] = None) -> None:
"""Log a scalar to all loggers that were added to Fabric.

Args:
name: The name of the metric to log.
value: The metric value to collect.
step: Optional step number. Most Logger implementations auto-increment the step value by one with every
log call. You can specify your own value here.
"""
self.log_dict(metrics={name: value}, step=step)

def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
"""Log multiple scalars at once to all loggers that were added to Fabric.

Args:
metrics: A dictionary where the key is the name of the metric and the value is the scalar to be logged.
step: Optional step number. Most Logger implementations auto-increment this value by one with every
log call. You can specify your own value here.
"""
for logger in self._loggers:
logger.log_metrics(metrics=metrics, step=step)

@staticmethod
def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int:
"""Helper function to seed everything without explicitly importing Lightning.
14 changes: 14 additions & 0 deletions src/lightning_fabric/loggers/__init__.py
@@ -0,0 +1,14 @@
# Copyright The PyTorch Lightning team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning_fabric.loggers.logger import Logger # noqa: F401
from lightning_fabric.loggers.tensorboard import TensorBoardLogger # noqa: F401
136 changes: 136 additions & 0 deletions src/lightning_fabric/loggers/logger.py
@@ -0,0 +1,136 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract base class used to build new loggers."""

from abc import ABC, abstractmethod
from argparse import Namespace
from functools import wraps
from typing import Any, Callable, Dict, Optional, Union

from torch import Tensor
from torch.nn import Module

from lightning_fabric.utilities.rank_zero import rank_zero_only


class Logger(ABC):
"""Base class for experiment loggers."""

@property
@abstractmethod
def name(self) -> Optional[str]:
"""Return the experiment name."""

@property
@abstractmethod
def version(self) -> Optional[Union[int, str]]:
"""Return the experiment version."""

@property
def root_dir(self) -> Optional[str]:
"""Return the root directory where all versions of an experiment get saved, or `None` if the logger does
not save data locally."""
return None

@property
def log_dir(self) -> Optional[str]:
"""Return directory the current version of the experiment gets saved, or `None` if the logger does not save
data locally."""
return None

@property
def group_separator(self) -> str:
"""Return the default separator used by the logger to group the data into subfolders."""
return "/"

@abstractmethod
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
"""Records metrics. This method logs metrics as soon as it received them.

Args:
metrics: Dictionary with metric names as keys and measured quantities as values
step: Step number at which the metrics should be recorded
"""
pass

@abstractmethod
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None:
"""Record hyperparameters.

Args:
params: :class:`~argparse.Namespace` or `Dict` containing the hyperparameters
args: Optional positional arguments, depends on the specific logger being used
kwargs: Optional keyword arguments, depends on the specific logger being used
"""

def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
"""Record model graph.

Args:
model: the model with an implementation of ``forward``.
input_array: input passed to `model.forward`
"""
pass

def save(self) -> None:
"""Save log data."""

def finalize(self, status: str) -> None:
"""Do any processing that is necessary to finalize an experiment.

Args:
status: Status that the experiment finished with (e.g. success, failed, aborted)
"""
self.save()


def rank_zero_experiment(fn: Callable) -> Callable:
"""Returns the real experiment on rank 0 and otherwise the _DummyExperiment."""

@wraps(fn)
def experiment(self) -> Union[Any, _DummyExperiment]: # type: ignore[no-untyped-def]
"""
Note:
``self`` is a custom logger instance. The loggers typically wrap an ``experiment`` method
with a ``@rank_zero_experiment`` decorator.

``Union[Any, _DummyExperiment]`` is used because the wrapped hooks have several return
types that are specific to the custom logger. The return type here can be considered as
``Union[return type of logger.experiment, _DummyExperiment]``.
"""

@rank_zero_only
def get_experiment() -> Callable:
return fn(self)

return get_experiment() or _DummyExperiment()

return experiment


class _DummyExperiment:
"""Dummy experiment."""

def nop(self, *args: Any, **kw: Any) -> None:
pass

def __getattr__(self, _: Any) -> Callable:
return self.nop

def __getitem__(self, idx: int) -> "_DummyExperiment":
# enables self.logger.experiment[0].add_image(...)
return self

def __setitem__(self, *args: Any, **kwargs: Any) -> None:
pass
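
A concrete logger typically wraps a lazily created ``experiment`` property with ``@rank_zero_experiment`` so that only rank 0 talks to the real tracking backend, while every other rank transparently receives the no-op ``_DummyExperiment``. A minimal sketch of that pattern (``SomeTrackingBackend`` and its ``record`` methods are hypothetical stand-ins):

.. code-block:: python

    class MyLogger(Logger):
        """Hypothetical logger illustrating the rank-zero experiment pattern."""

        def __init__(self):
            self._experiment = None

        @property
        def name(self):
            return "my_logger"

        @property
        def version(self):
            return "0"

        @property
        @rank_zero_experiment
        def experiment(self):
            # Created lazily, and only on rank 0. On all other ranks the
            # decorator returns _DummyExperiment, which swallows any call.
            if self._experiment is None:
                self._experiment = SomeTrackingBackend()  # hypothetical backend
            return self._experiment

        def log_metrics(self, metrics, step=None):
            # Safe on every rank: on non-zero ranks this call hits the dummy
            self.experiment.record(metrics, step)  # hypothetical API

        def log_hyperparams(self, params, *args, **kwargs):
            self.experiment.record_config(params)  # hypothetical API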