diff --git a/docs/source-pytorch/fabric/fabric.rst b/docs/source-pytorch/fabric/fabric.rst
index e87a5d82e323d..a7ee0d720104f 100644
--- a/docs/source-pytorch/fabric/fabric.rst
+++ b/docs/source-pytorch/fabric/fabric.rst
@@ -385,6 +385,30 @@ Then, in your training loop, you can call a hook by its name. Any callback objec
     fabric.call("on_train_epoch_end", results={...})
 
 
+loggers
+=======
+
+Attach one or several loggers/experiment trackers to Fabric for convenient logging of metrics.
+
+.. code-block:: python
+
+    # Default used by Fabric, no loggers are active
+    fabric = Fabric(loggers=[])
+
+    # Log to a single logger
+    fabric = Fabric(loggers=TensorBoardLogger(...))
+
+    # Or multiple instances
+    fabric = Fabric(loggers=[logger1, logger2, ...])
+
+Anywhere in your training loop, you can log metrics to all loggers at once:
+
+.. code-block:: python
+
+    fabric.log("loss", loss)
+    fabric.log_dict({"loss": loss, "accuracy": acc})
+
+
 ----------
 
 
@@ -613,3 +637,29 @@ It is useful when building a Trainer that allows the user to run arbitrary code
 
     # Only the callbacks that have this method defined will be executed
     fabric.call("undefined")
+
+
+log and log_dict
+================
+
+These methods allow you to send scalar metrics to the loggers registered in Fabric.
+
+.. code-block:: python
+
+    # Set the logger in Fabric
+    fabric = Fabric(loggers=TensorBoardLogger(...))
+
+    # Anywhere in your training loop or model:
+    fabric.log("loss", loss)
+
+    # Or send multiple metrics at once:
+    fabric.log_dict({"loss": loss, "accuracy": acc})
+
+If no loggers are given to Fabric (default), ``log`` and ``log_dict`` won't do anything.
+Here is what's happening under the hood (pseudocode) when you call ``.log()`` or ``.log_dict()``:
+
+.. code-block:: python
+
+    # When you call .log() or .log_dict(), we do this:
+    for logger in fabric.loggers:
+        logger.log_metrics(metrics=metrics, step=step)
diff --git a/requirements/fabric/test.txt b/requirements/fabric/test.txt
index abb2cf558488a..c4e1cf8866576 100644
--- a/requirements/fabric/test.txt
+++ b/requirements/fabric/test.txt
@@ -4,3 +4,4 @@ pytest==7.2.0
 pytest-cov==4.0.0
 pre-commit==2.20.0
 click==8.1.3
+tensorboard>=2.9.1, <2.12.0
diff --git a/src/lightning/__init__.py b/src/lightning/__init__.py
index a0d5835536792..b66e83a1e2b17 100644
--- a/src/lightning/__init__.py
+++ b/src/lightning/__init__.py
@@ -40,7 +40,6 @@ def _detail(self: Any, message: str, *args: Any, **kwargs: Any) -> None:
 from lightning.fabric.utilities.seed import seed_everything  # noqa: E402
 from lightning.pytorch.callbacks import Callback  # noqa: E402
 from lightning.pytorch.core import LightningDataModule, LightningModule  # noqa: E402
-from lightning.pytorch.lite import LightningLite  # noqa: E402
 from lightning.pytorch.trainer import Trainer  # noqa: E402
 
 import lightning.app  # isort: skip # noqa: E402
@@ -61,7 +60,6 @@ def _detail(self: Any, message: str, *args: Any, **kwargs: Any) -> None:
     "LightningModule",
     "Callback",
     "seed_everything",
-    "LightningLite",
     "Fabric",
     "storage",
     "pdb",
diff --git a/src/lightning_fabric/CHANGELOG.md b/src/lightning_fabric/CHANGELOG.md
index afb900294f64b..0a51fc3858cc4 100644
--- a/src/lightning_fabric/CHANGELOG.md
+++ b/src/lightning_fabric/CHANGELOG.md
@@ -30,6 +30,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for managing callbacks via `Fabric(callbacks=...)` and emitting events through `Fabric.call()` ([#16074](https://github.com/Lightning-AI/lightning/issues/16074)) +- Added Logger support ([#16121](https://github.com/Lightning-AI/lightning/issues/16121)) + * Added `Fabric(loggers=...)` to support different Logger frameworks in Fabric + * Added `Fabric.log` for logging scalars using multiple loggers + * Added `Fabric.log_dict` for logging a dictionary of multiple metrics at once + * Added `Fabric.loggers` and `Fabric.logger` attributes to access the individual logger instances + - Added support for a consistent `.zero_grad(set_to_none=...)` on the wrapped optimizer regardless of which strategy is used ([#16275](https://github.com/Lightning-AI/lightning/issues/16275)) diff --git a/src/lightning_fabric/fabric.py b/src/lightning_fabric/fabric.py index 75ba50e2b5dfa..fbfe572ebbdb5 100644 --- a/src/lightning_fabric/fabric.py +++ b/src/lightning_fabric/fabric.py @@ -22,10 +22,13 @@ import torch.nn as nn from lightning_utilities.core.apply_func import apply_to_collection from lightning_utilities.core.overrides import is_overridden +from lightning_utilities.core.rank_zero import rank_zero_warn from torch import Tensor from torch.optim import Optimizer from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler +from lightning_fabric.loggers import Logger + from lightning_fabric.plugins import Precision # avoid circular imports: # isort: split from lightning_fabric.accelerators.accelerator import Accelerator from lightning_fabric.connector import _Connector, _PLUGIN_INPUT, _PRECISION_INPUT @@ -47,7 +50,6 @@ has_iterable_dataset, ) from lightning_fabric.utilities.distributed import DistributedSamplerWrapper -from lightning_fabric.utilities.rank_zero import rank_zero_warn from lightning_fabric.utilities.seed import seed_everything from lightning_fabric.utilities.warnings import PossibleUserWarning from lightning_fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer @@ -74,6 +76,10 @@ class Fabric: precision: Double precision (``64``), full precision (``32``), half precision (``16``), or bfloat16 precision (``"bf16"``). plugins: One or several custom plugins + callbacks: A single callback or a list of callbacks. A callback can contain any arbitrary methods that + can be invoked through :meth:`lightning_fabric.fabric.Fabric.call` by the user. + loggers: A single logger or a list of loggers. See :meth:`lightning_fabric.fabric.Fabric.log` for more + information. 
""" def __init__( @@ -85,6 +91,7 @@ def __init__( precision: _PRECISION_INPUT = 32, plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None, callbacks: Optional[Union[List[Any], Any]] = None, + loggers: Optional[Union[Logger, List[Logger]]] = None, ) -> None: self._connector = _Connector( accelerator=accelerator, @@ -99,6 +106,8 @@ def __init__( self._precision: Precision = self._strategy.precision callbacks = callbacks if callbacks is not None else [] self._callbacks = callbacks if isinstance(callbacks, list) else [callbacks] + loggers = loggers if loggers is not None else [] + self._loggers = loggers if isinstance(loggers, list) else [loggers] self._models_setup: int = 0 self._prepare_run_method() @@ -148,6 +157,16 @@ def is_global_zero(self) -> bool: """Whether this rank is rank zero.""" return self._strategy.is_global_zero + @property + def loggers(self) -> List[Logger]: + """Returns all loggers passed to Fabric.""" + return self._loggers + + @property + def logger(self) -> Logger: + """Returns the first logger in the list passed to Fabric, which is considered the main logger.""" + return self._loggers[0] + def run(self, *args: Any, **kwargs: Any) -> Any: """All the code inside this run method gets accelerated by Fabric. @@ -573,6 +592,28 @@ def on_train_epoch_end(self, results): # method(self, *args, y=1) # method(self, *args, **kwargs) + def log(self, name: str, value: Any, step: Optional[int] = None) -> None: + """Log a scalar to all loggers that were added to Fabric. + + Args: + name: The name of the metric to log. + value: The metric value to collect. + step: Optional step number. Most Logger implementations auto-increment the step value by one with every + log call. You can specify your own value here. + """ + self.log_dict(metrics={name: value}, step=step) + + def log_dict(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None: + """Log multiple scalars at once to all loggers that were added to Fabric. + + Args: + metrics: A dictionary where the key is the name of the metric and the value the scalar to be logged. + step: Optional step number. Most Logger implementations auto-increment this value by one with every + log call. You can specify your own value here. + """ + for logger in self._loggers: + logger.log_metrics(metrics=metrics, step=step) + @staticmethod def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int: """Helper function to seed everything without explicitly importing Lightning. diff --git a/src/lightning_fabric/loggers/__init__.py b/src/lightning_fabric/loggers/__init__.py new file mode 100644 index 0000000000000..03c21d71f8304 --- /dev/null +++ b/src/lightning_fabric/loggers/__init__.py @@ -0,0 +1,14 @@ +# Copyright The PyTorch Lightning team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from lightning_fabric.loggers.logger import Logger  # noqa: F401
+from lightning_fabric.loggers.tensorboard import TensorBoardLogger  # noqa: F401
diff --git a/src/lightning_fabric/loggers/logger.py b/src/lightning_fabric/loggers/logger.py
new file mode 100644
index 0000000000000..a66023eb747d1
--- /dev/null
+++ b/src/lightning_fabric/loggers/logger.py
@@ -0,0 +1,136 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Abstract base class used to build new loggers."""
+
+from abc import ABC, abstractmethod
+from argparse import Namespace
+from functools import wraps
+from typing import Any, Callable, Dict, Optional, Union
+
+from torch import Tensor
+from torch.nn import Module
+
+from lightning_fabric.utilities.rank_zero import rank_zero_only
+
+
+class Logger(ABC):
+    """Base class for experiment loggers."""
+
+    @property
+    @abstractmethod
+    def name(self) -> Optional[str]:
+        """Return the experiment name."""
+
+    @property
+    @abstractmethod
+    def version(self) -> Optional[Union[int, str]]:
+        """Return the experiment version."""
+
+    @property
+    def root_dir(self) -> Optional[str]:
+        """Return the root directory where all versions of an experiment get saved, or `None` if the logger does
+        not save data locally."""
+        return None
+
+    @property
+    def log_dir(self) -> Optional[str]:
+        """Return the directory where the current version of the experiment gets saved, or `None` if the logger
+        does not save data locally."""
+        return None
+
+    @property
+    def group_separator(self) -> str:
+        """Return the default separator used by the logger to group the data into subfolders."""
+        return "/"
+
+    @abstractmethod
+    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+        """Records metrics. This method logs metrics as soon as it receives them.
+
+        Args:
+            metrics: Dictionary with metric names as keys and measured quantities as values
+            step: Step number at which the metrics should be recorded
+        """
+        pass
+
+    @abstractmethod
+    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None:
+        """Record hyperparameters.
+
+        Args:
+            params: :class:`~argparse.Namespace` or `Dict` containing the hyperparameters
+            args: Optional positional arguments, depends on the specific logger being used
+            kwargs: Optional keyword arguments, depends on the specific logger being used
+        """
+
+    def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
+        """Record model graph.
+
+        Args:
+            model: the model with an implementation of ``forward``.
+            input_array: input passed to `model.forward`
+        """
+        pass
+
+    def save(self) -> None:
+        """Save log data."""
+
+    def finalize(self, status: str) -> None:
+        """Do any processing that is necessary to finalize an experiment.
+
+        Args:
+            status: Status that the experiment finished with (e.g. 
success, failed, aborted) + """ + self.save() + + +def rank_zero_experiment(fn: Callable) -> Callable: + """Returns the real experiment on rank 0 and otherwise the _DummyExperiment.""" + + @wraps(fn) + def experiment(self) -> Union[Any, _DummyExperiment]: # type: ignore[no-untyped-def] + """ + Note: + ``self`` is a custom logger instance. The loggers typically wrap an ``experiment`` method + with a ``@rank_zero_experiment`` decorator. + + ``Union[Any, _DummyExperiment]`` is used because the wrapped hooks have several return + types that are specific to the custom logger. The return type here can be considered as + ``Union[return type of logger.experiment, _DummyExperiment]``. + """ + + @rank_zero_only + def get_experiment() -> Callable: + return fn(self) + + return get_experiment() or _DummyExperiment() + + return experiment + + +class _DummyExperiment: + """Dummy experiment.""" + + def nop(self, *args: Any, **kw: Any) -> None: + pass + + def __getattr__(self, _: Any) -> Callable: + return self.nop + + def __getitem__(self, idx: int) -> "_DummyExperiment": + # enables self.logger.experiment[0].add_image(...) + return self + + def __setitem__(self, *args: Any, **kwargs: Any) -> None: + pass diff --git a/src/lightning_fabric/loggers/tensorboard.py b/src/lightning_fabric/loggers/tensorboard.py new file mode 100644 index 0000000000000..ca694d9ea30c5 --- /dev/null +++ b/src/lightning_fabric/loggers/tensorboard.py @@ -0,0 +1,310 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from argparse import Namespace +from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, Union + +import numpy as np +from lightning_utilities.core.imports import RequirementCache +from torch import Tensor +from torch.nn import Module + +from lightning_fabric.loggers.logger import Logger, rank_zero_experiment +from lightning_fabric.utilities.cloud_io import get_filesystem +from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict +from lightning_fabric.utilities.logger import _sanitize_params as _utils_sanitize_params +from lightning_fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn +from lightning_fabric.utilities.types import _PATH + +log = logging.getLogger(__name__) + +_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard") +_TENSORBOARDX_AVAILABLE = RequirementCache("tensorboardX") +if TYPE_CHECKING: + # assumes at least one will be installed when type checking + if _TENSORBOARD_AVAILABLE: + from torch.utils.tensorboard import SummaryWriter + else: + from tensorboardX import SummaryWriter # type: ignore[no-redef] + + +class TensorBoardLogger(Logger): + r""" + Log to local file system in `TensorBoard `_ format. + + Implemented using :class:`~tensorboardX.SummaryWriter`. Logs are saved to + ``os.path.join(root_dir, name, version)``. This is the recommended logger in Lightning Fabric. 
+
+    Args:
+        root_dir: The root directory in which all your experiments with different names and versions will be stored.
+        name: Experiment name. Defaults to ``'lightning_logs'``. If it is the empty string then no per-experiment
+            subdirectory is used.
+        version: Experiment version. If version is not specified the logger inspects the save
+            directory for existing versions, then automatically assigns the next available version.
+            If it is a string then it is used as the run-specific subdirectory name,
+            otherwise ``'version_${version}'`` is used.
+        default_hp_metric: Enables a placeholder metric with key `hp_metric` when `log_hyperparams` is
+            called without a metric (otherwise calls to ``log_hyperparams`` without a metric are ignored).
+        prefix: A string to put at the beginning of all metric keys.
+        sub_dir: Sub-directory to group TensorBoard logs. If a ``sub_dir`` argument is passed
+            then logs are saved in ``/root_dir/name/version/sub_dir/``. Defaults to ``None`` in which case
+            logs are saved in ``/root_dir/name/version/``.
+        \**kwargs: Additional arguments used by :class:`tensorboardX.SummaryWriter` can be passed as keyword
+            arguments in this logger. To automatically flush to disk, `max_queue` sets the size
+            of the queue for pending logs before flushing. `flush_secs` determines how many seconds
+            elapse before flushing.
+
+
+    Example::
+
+        from lightning.fabric.loggers import TensorBoardLogger
+
+        logger = TensorBoardLogger("path/to/logs/root", name="my_model")
+        logger.log_hyperparams({"epochs": 5, "optimizer": "Adam"})
+        logger.log_metrics({"acc": 0.75})
+        logger.finalize("success")
+    """
+    LOGGER_JOIN_CHAR = "-"
+
+    def __init__(
+        self,
+        root_dir: _PATH,
+        name: Optional[str] = "lightning_logs",
+        version: Optional[Union[int, str]] = None,
+        default_hp_metric: bool = True,
+        prefix: str = "",
+        sub_dir: Optional[_PATH] = None,
+        **kwargs: Any,
+    ):
+        if not _TENSORBOARD_AVAILABLE and not _TENSORBOARDX_AVAILABLE:
+            raise ModuleNotFoundError(
+                "Neither `tensorboard` nor `tensorboardX` is available. Try `pip install`ing either."
+            )
+        super().__init__()
+        root_dir = os.fspath(root_dir)
+        self._root_dir = root_dir
+        self._name = name or ""
+        self._version = version
+        self._sub_dir = None if sub_dir is None else os.fspath(sub_dir)
+
+        self._default_hp_metric = default_hp_metric
+        self._prefix = prefix
+        self._fs = get_filesystem(root_dir)
+
+        self._experiment: Optional["SummaryWriter"] = None
+        self._kwargs = kwargs
+
+    @property
+    def name(self) -> str:
+        """Get the name of the experiment.
+
+        Returns:
+            The name of the experiment.
+        """
+        return self._name
+
+    @property
+    def version(self) -> Union[int, str]:
+        """Get the experiment version.
+
+        Returns:
+            The experiment version if specified else the next version.
+        """
+        if self._version is None:
+            self._version = self._get_next_version()
+        return self._version
+
+    @property
+    def root_dir(self) -> str:
+        """Gets the save directory where the TensorBoard experiments are saved.
+
+        Returns:
+            The local path to the save directory where the TensorBoard experiments are saved.
+        """
+        return self._root_dir
+
+    @property
+    def log_dir(self) -> str:
+        """The directory for this run's TensorBoard logs.
+
+        By default, it is named ``'version_${self.version}'`` but it can be overridden by passing a string value
+        for the constructor's version parameter instead of ``None`` or an int.
+ """ + version = self.version if isinstance(self.version, str) else f"version_{self.version}" + log_dir = os.path.join(self.root_dir, self.name, version) + if isinstance(self.sub_dir, str): + log_dir = os.path.join(log_dir, self.sub_dir) + log_dir = os.path.expandvars(log_dir) + log_dir = os.path.expanduser(log_dir) + return log_dir + + @property + def sub_dir(self) -> Optional[str]: + """Gets the sub directory where the TensorBoard experiments are saved. + + Returns: + The local path to the sub directory where the TensorBoard experiments are saved. + """ + return self._sub_dir + + @property + @rank_zero_experiment + def experiment(self) -> "SummaryWriter": + """Actual tensorboard object. To use TensorBoard features anywhere in your code, do the following. + + Example:: + + logger.experiment.some_tensorboard_function() + """ + if self._experiment is not None: + return self._experiment + + assert rank_zero_only.rank == 0, "tried to init log dirs in non global_rank=0" + if self.root_dir: + self._fs.makedirs(self.root_dir, exist_ok=True) + + if _TENSORBOARD_AVAILABLE: + from torch.utils.tensorboard import SummaryWriter + else: + from tensorboardX import SummaryWriter # type: ignore[no-redef] + + self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs) + return self._experiment + + @rank_zero_only + def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: + assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" + + metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) + + for k, v in metrics.items(): + if isinstance(v, Tensor): + v = v.item() + + if isinstance(v, dict): + self.experiment.add_scalars(k, v, step) + else: + try: + self.experiment.add_scalar(k, v, step) + # TODO(fabric): specify the possible exception + except Exception as ex: + m = f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor." + raise ValueError(m) from ex + + @rank_zero_only + def log_hyperparams( + self, params: Union[Dict[str, Any], Namespace], metrics: Optional[Dict[str, Any]] = None + ) -> None: + """Record hyperparameters. TensorBoard logs with and without saved hyperparameters are incompatible, the + hyperparameters are then not displayed in the TensorBoard. Please delete or move the previously saved logs + to display the new ones with hyperparameters. 
+
+        Args:
+            params: a dictionary-like container with the hyperparameters
+            metrics: Dictionary with metric names as keys and measured quantities as values
+        """
+        params = _convert_params(params)
+
+        # format params into a form suitable for TensorBoard
+        params = _flatten_dict(params)
+        params = self._sanitize_params(params)
+
+        if metrics is None:
+            if self._default_hp_metric:
+                metrics = {"hp_metric": -1}
+        elif not isinstance(metrics, dict):
+            metrics = {"hp_metric": metrics}
+
+        if metrics:
+            self.log_metrics(metrics, 0)
+
+        if _TENSORBOARD_AVAILABLE:
+            from torch.utils.tensorboard.summary import hparams
+        else:
+            from tensorboardX.summary import hparams  # type: ignore[no-redef]
+
+        exp, ssi, sei = hparams(params, metrics)
+        writer = self.experiment._get_file_writer()
+        writer.add_summary(exp)
+        writer.add_summary(ssi)
+        writer.add_summary(sei)
+
+    @rank_zero_only
+    def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None:
+        model_example_input = getattr(model, "example_input_array", None)
+        input_array = model_example_input if input_array is None else input_array
+
+        if input_array is None:
+            rank_zero_warn(
+                "Could not log computational graph to TensorBoard: The `model.example_input_array` attribute"
+                " is not set or `input_array` was not given."
+            )
+        elif not isinstance(input_array, (Tensor, tuple)):
+            rank_zero_warn(
+                "Could not log computational graph to TensorBoard: The `input_array` or `model.example_input_array`"
+                f" has type {type(input_array)} which can't be traced by TensorBoard. Make the input array a tuple"
+                f" representing the positional arguments to the model's `forward()` implementation."
+            )
+        elif callable(getattr(model, "_on_before_batch_transfer", None)) and callable(
+            getattr(model, "_apply_batch_transfer_handler", None)
+        ):
+            # this is probably a LightningModule
+            input_array = model._on_before_batch_transfer(input_array)  # type: ignore[operator]
+            input_array = model._apply_batch_transfer_handler(input_array)  # type: ignore[operator]
+            self.experiment.add_graph(model, input_array)
+
+    @rank_zero_only
+    def save(self) -> None:
+        self.experiment.flush()
+
+    @rank_zero_only
+    def finalize(self, status: str) -> None:
+        if self._experiment is not None:
+            self.experiment.flush()
+            self.experiment.close()
+
+    def _get_next_version(self) -> int:
+        save_dir = os.path.join(self.root_dir, self.name)
+
+        try:
+            listdir_info = self._fs.listdir(save_dir)
+        except OSError:
+            # TODO(fabric): This message can be confusing (did user do something wrong?). Improve it or remove it.
+ log.warning("Missing logger folder: %s", save_dir) + return 0 + + existing_versions = [] + for listing in listdir_info: + d = listing["name"] + bn = os.path.basename(d) + if self._fs.isdir(d) and bn.startswith("version_"): + dir_ver = bn.split("_")[1].replace("/", "") + existing_versions.append(int(dir_ver)) + if len(existing_versions) == 0: + return 0 + + return max(existing_versions) + 1 + + @staticmethod + def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: + params = _utils_sanitize_params(params) + # logging of arrays with dimension > 1 is not supported, sanitize as string + return {k: str(v) if isinstance(v, (Tensor, np.ndarray)) and v.ndim > 1 else v for k, v in params.items()} + + def __getstate__(self) -> Dict[str, Any]: + state = self.__dict__.copy() + state["_experiment"] = None + return state diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index cf86a4ccb756c..aa71bf2fab120 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -33,6 +33,7 @@ import lightning_fabric as lf import pytorch_lightning as pl +from lightning_fabric.loggers import Logger as FabricLogger from lightning_fabric.utilities.apply_func import convert_to_tensors from lightning_fabric.utilities.cloud_io import get_filesystem from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin @@ -291,14 +292,20 @@ def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None: self._truncated_bptt_steps = truncated_bptt_steps @property - def logger(self) -> Optional[Logger]: + def logger(self) -> Optional[Union[Logger, FabricLogger]]: """Reference to the logger object in the Trainer.""" + if self._fabric is not None: + return self._fabric.logger return self._trainer.logger if self._trainer is not None else None @property - def loggers(self) -> List[Logger]: + def loggers(self) -> Union[List[Logger], List[FabricLogger]]: """Reference to the list of loggers in the Trainer.""" - return self.trainer.loggers if self._trainer else [] + if self._fabric is not None: + return self._fabric.loggers + elif self._trainer is not None: + return self._trainer.loggers + return [] # type: ignore[return-value] def _call_batch_hook(self, hook_name: str, *args: Any) -> Any: if self._trainer: diff --git a/src/pytorch_lightning/loggers/comet.py b/src/pytorch_lightning/loggers/comet.py index e66281a420fca..7f38545955fa6 100644 --- a/src/pytorch_lightning/loggers/comet.py +++ b/src/pytorch_lightning/loggers/comet.py @@ -23,8 +23,8 @@ from lightning_utilities.core.imports import module_available from torch import Tensor +from torch.nn import Module -import pytorch_lightning as pl from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -423,6 +423,6 @@ def __getstate__(self) -> Dict[str, Any]: state["_experiment"] = None return state - def log_graph(self, model: "pl.LightningModule", input_array: Optional[Tensor] = None) -> None: + def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: if self._experiment is not None: self._experiment.set_model_graph(model) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index 30bd9e66261e9..3e96e20952b28 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -16,46 +16,19 @@ import 
functools import operator -from abc import ABC, abstractmethod -from argparse import Namespace +from abc import ABC from collections import defaultdict -from functools import wraps -from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Dict, Mapping, Optional, Sequence import numpy as np -from torch import Tensor -import pytorch_lightning as pl +from lightning_fabric.loggers import Logger as FabricLogger +from lightning_fabric.loggers.logger import _DummyExperiment as DummyExperiment # for backward compatibility +from lightning_fabric.loggers.logger import rank_zero_experiment # noqa: F401 # for backward compatibility from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.utilities.rank_zero import rank_zero_only -def rank_zero_experiment(fn: Callable) -> Callable: - """Returns the real experiment on rank 0 and otherwise the DummyExperiment.""" - - @wraps(fn) - def experiment(self) -> Union[Any, DummyExperiment]: # type: ignore[no-untyped-def] - """ - Note: - ``self`` is a custom logger instance. The loggers typically wrap an ``experiment`` method - with a ``@rank_zero_experiment`` decorator. An exception is that ``loggers.neptune`` wraps - ``experiment`` and ``run`` with rank_zero_experiment. - - ``Union[Any, DummyExperiment]`` is used because the wrapped hooks have several return - types that are specific to the custom logger. The return type here can be considered as - ``Union[return type of logger.experiment, DummyExperiment]``. - """ - - @rank_zero_only - def get_experiment() -> Callable: - return fn(self) - - return get_experiment() or DummyExperiment() - - return experiment - - -class Logger(ABC): +class Logger(FabricLogger, ABC): """Base class for experiment loggers.""" def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: @@ -66,84 +39,12 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: """ pass - @abstractmethod - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: - """Records metrics. This method logs metrics as soon as it received them. - - Args: - metrics: Dictionary with metric names as keys and measured quantities as values - step: Step number at which the metrics should be recorded - """ - pass - - @abstractmethod - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None: - """Record hyperparameters. - - Args: - params: :class:`~argparse.Namespace` or `Dict` containing the hyperparameters - args: Optional positional arguments, depends on the specific logger being used - kwargs: Optional keyword arguments, depends on the specific logger being used - """ - - def log_graph(self, model: "pl.LightningModule", input_array: Optional[Tensor] = None) -> None: - """Record model graph. - - Args: - model: lightning model - input_array: input passes to `model.forward` - """ - pass - - def save(self) -> None: - """Save log data.""" - - def finalize(self, status: str) -> None: - """Do any processing that is necessary to finalize an experiment. - - Args: - status: Status that the experiment finished with (e.g. 
success, failed, aborted) - """ - self.save() - @property def save_dir(self) -> Optional[str]: """Return the root directory where experiment logs get saved, or `None` if the logger does not save data locally.""" return None - @property - def group_separator(self) -> str: - """Return the default separator used by the logger to group the data into subfolders.""" - return "/" - - @property - @abstractmethod - def name(self) -> Optional[str]: - """Return the experiment name.""" - - @property - @abstractmethod - def version(self) -> Optional[Union[int, str]]: - """Return the experiment version.""" - - -class DummyExperiment: - """Dummy experiment.""" - - def nop(self, *args: Any, **kw: Any) -> None: - pass - - def __getattr__(self, _: Any) -> Callable: - return self.nop - - def __getitem__(self, idx: int) -> "DummyExperiment": - # enables self.logger.experiment[0].add_image(...) - return self - - def __setitem__(self, *args: Any, **kwargs: Any) -> None: - pass - class DummyLogger(Logger): """Dummy logger for internal use. diff --git a/src/pytorch_lightning/loggers/tensorboard.py b/src/pytorch_lightning/loggers/tensorboard.py index a2553576601e0..6dd8571df3690 100644 --- a/src/pytorch_lightning/loggers/tensorboard.py +++ b/src/pytorch_lightning/loggers/tensorboard.py @@ -19,37 +19,28 @@ import logging import os from argparse import Namespace -from typing import Any, Dict, Mapping, Optional, TYPE_CHECKING, Union +from typing import Any, Dict, Optional, Union -import numpy as np -from lightning_utilities.core.imports import RequirementCache from torch import Tensor import pytorch_lightning as pl -from lightning_fabric.utilities.cloud_io import get_filesystem -from lightning_fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict -from lightning_fabric.utilities.logger import _sanitize_params as _utils_sanitize_params +from lightning_fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE +from lightning_fabric.loggers.tensorboard import TensorBoardLogger as FabricTensorBoardLogger +from lightning_fabric.utilities.logger import _convert_params from lightning_fabric.utilities.types import _PATH +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.saving import save_hparams_to_yaml -from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment +from pytorch_lightning.loggers.logger import Logger from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn log = logging.getLogger(__name__) -_TENSORBOARD_AVAILABLE = RequirementCache("tensorboard") -if TYPE_CHECKING: - # assumes at least one will be installed when type checking - if _TENSORBOARD_AVAILABLE: - from torch.utils.tensorboard import SummaryWriter - else: - from tensorboardX import SummaryWriter # type: ignore[no-redef] - if _OMEGACONF_AVAILABLE: from omegaconf import Container, OmegaConf -class TensorBoardLogger(Logger): +class TensorBoardLogger(Logger, FabricTensorBoardLogger): r""" Log to local file system in `TensorBoard `_ format. 
@@ -100,7 +91,6 @@ class TensorBoardLogger(Logger): >>> shutil.rmtree(tmp) """ NAME_HPARAMS_FILE = "hparams.yaml" - LOGGER_JOIN_CHAR = "-" def __init__( self, @@ -113,23 +103,19 @@ def __init__( sub_dir: Optional[_PATH] = None, **kwargs: Any, ): - super().__init__() - save_dir = os.fspath(save_dir) - self._save_dir = save_dir - self._name = name or "" - self._version = version - self._sub_dir = None if sub_dir is None else os.fspath(sub_dir) + super().__init__( + root_dir=save_dir, + name=name, + version=version, + default_hp_metric=default_hp_metric, + prefix=prefix, + sub_dir=sub_dir, + **kwargs, + ) if log_graph and not _TENSORBOARD_AVAILABLE: rank_zero_warn("You set `TensorBoardLogger(log_graph=True)` but `tensorboard` is not available.") self._log_graph = log_graph and _TENSORBOARD_AVAILABLE - - self._default_hp_metric = default_hp_metric - self._prefix = prefix - self._fs = get_filesystem(save_dir) - - self._experiment: Optional["SummaryWriter"] = None self.hparams: Union[Dict[str, Any], Namespace] = {} - self._kwargs = kwargs @property def root_dir(self) -> str: @@ -138,7 +124,7 @@ def root_dir(self) -> str: If the experiment name parameter is an empty string, no experiment subdirectory is used and the checkpoint will be saved in "save_dir/version" """ - return os.path.join(self.save_dir, self.name) + return os.path.join(super().root_dir, self.name) @property def log_dir(self) -> str: @@ -163,43 +149,7 @@ def save_dir(self) -> str: Returns: The local path to the save directory where the TensorBoard experiments are saved. """ - return self._save_dir - - @property - def sub_dir(self) -> Optional[str]: - """Gets the sub directory where the TensorBoard experiments are saved. - - Returns: - The local path to the sub directory where the TensorBoard experiments are saved. - """ - return self._sub_dir - - @property - @rank_zero_experiment - def experiment(self) -> "SummaryWriter": - r""" - Actual tensorboard object. To use TensorBoard features in your - :class:`~pytorch_lightning.core.module.LightningModule` do the following. 
- - Example:: - - self.logger.experiment.some_tensorboard_function() - - """ - if self._experiment is not None: - return self._experiment - - assert rank_zero_only.rank == 0, "tried to init log dirs in non global_rank=0" - if self.root_dir: - self._fs.makedirs(self.root_dir, exist_ok=True) - - if _TENSORBOARD_AVAILABLE: - from torch.utils.tensorboard import SummaryWriter - else: - from tensorboardX import SummaryWriter # type: ignore[no-redef] - - self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs) - return self._experiment + return self._root_dir @rank_zero_only def log_hyperparams( @@ -213,7 +163,6 @@ def log_hyperparams( params: a dictionary-like container with the hyperparameters metrics: Dictionary with metric names as keys and measured quantities as values """ - params = _convert_params(params) # store params to output @@ -222,49 +171,7 @@ def log_hyperparams( else: self.hparams.update(params) - # format params into the suitable for tensorboard - params = _flatten_dict(params) - params = self._sanitize_params(params) - - if metrics is None: - if self._default_hp_metric: - metrics = {"hp_metric": -1} - elif not isinstance(metrics, dict): - metrics = {"hp_metric": metrics} - - if metrics: - self.log_metrics(metrics, 0) - - if _TENSORBOARD_AVAILABLE: - from torch.utils.tensorboard.summary import hparams - else: - from tensorboardX.summary import hparams # type: ignore[no-redef] - - exp, ssi, sei = hparams(params, metrics) - writer = self.experiment._get_file_writer() - writer.add_summary(exp) - writer.add_summary(ssi) - writer.add_summary(sei) - - @rank_zero_only - def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None: - assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" - - metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) - - for k, v in metrics.items(): - if isinstance(v, Tensor): - v = v.item() - - if isinstance(v, dict): - self.experiment.add_scalars(k, v, step) - else: - try: - self.experiment.add_scalar(k, v, step) - # todo: specify the possible exception - except Exception as ex: - m = f"\n you tried to log {v} which is currently not supported. Try a dict or a scalar/tensor." - raise ValueError(m) from ex + return super().log_hyperparams(params=params, metrics=metrics) @rank_zero_only def log_graph(self, model: "pl.LightningModule", input_array: Optional[Tensor] = None) -> None: @@ -304,33 +211,18 @@ def save(self) -> None: @rank_zero_only def finalize(self, status: str) -> None: - if self._experiment is not None: - self.experiment.flush() - self.experiment.close() - + super().finalize(status) if status == "success": # saving hparams happens independent of experiment manager self.save() - @property - def name(self) -> str: - """Get the name of the experiment. + def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: + """Called after model checkpoint callback saves a new checkpoint. - Returns: - The name of the experiment. - """ - return self._name - - @property - def version(self) -> Union[int, str]: - """Get the experiment version. - - Returns: - The experiment version if specified else the next version. 
+ Args: + checkpoint_callback: the model checkpoint callback instance """ - if self._version is None: - self._version = self._get_next_version() - return self._version + pass def _get_next_version(self) -> int: root_dir = self.root_dir @@ -352,14 +244,3 @@ def _get_next_version(self) -> int: return 0 return max(existing_versions) + 1 - - @staticmethod - def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]: - params = _utils_sanitize_params(params) - # logging of arrays with dimension > 1 is not supported, sanitize as string - return {k: str(v) if isinstance(v, (Tensor, np.ndarray)) and v.ndim > 1 else v for k, v in params.items()} - - def __getstate__(self) -> Dict[str, Any]: - state = self.__dict__.copy() - state["_experiment"] = None - return state diff --git a/tests/tests_fabric/loggers/__init__.py b/tests/tests_fabric/loggers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tests_fabric/loggers/test_tensorboard.py b/tests/tests_fabric/loggers/test_tensorboard.py new file mode 100644 index 0000000000000..a0569248be143 --- /dev/null +++ b/tests/tests_fabric/loggers/test_tensorboard.py @@ -0,0 +1,229 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from argparse import Namespace +from unittest import mock +from unittest.mock import Mock + +import numpy as np +import pytest +import torch +from tests_fabric.test_fabric import BoringModel + +from lightning_fabric.loggers import TensorBoardLogger +from lightning_fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE + + +def test_tensorboard_automatic_versioning(tmpdir): + """Verify that automatic versioning works.""" + root_dir = tmpdir / "tb_versioning" + root_dir.mkdir() + (root_dir / "version_0").mkdir() + (root_dir / "version_1").mkdir() + + logger = TensorBoardLogger(root_dir=tmpdir, name="tb_versioning") + assert logger.version == 2 + + +def test_tensorboard_manual_versioning(tmpdir): + """Verify that manual versioning works.""" + root_dir = tmpdir / "tb_versioning" + root_dir.mkdir() + (root_dir / "version_0").mkdir() + (root_dir / "version_1").mkdir() + (root_dir / "version_2").mkdir() + + logger = TensorBoardLogger(root_dir=tmpdir, name="tb_versioning", version=1) + assert logger.version == 1 + + +def test_tensorboard_named_version(tmpdir): + """Verify that manual versioning works for string versions, e.g. 
'2020-02-05-162402'.""" + name = "tb_versioning" + (tmpdir / name).mkdir() + expected_version = "2020-02-05-162402" + + logger = TensorBoardLogger(root_dir=tmpdir, name=name, version=expected_version) + logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written + + assert logger.version == expected_version + assert os.listdir(tmpdir / name) == [expected_version] + assert os.listdir(tmpdir / name / expected_version) + + +@pytest.mark.parametrize("name", ["", None]) +def test_tensorboard_no_name(tmpdir, name): + """Verify that None or empty name works.""" + logger = TensorBoardLogger(root_dir=tmpdir, name=name) + logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5}) # Force data to be written + assert os.path.normpath(logger.root_dir) == tmpdir # use os.path.normpath to handle trailing / + assert os.listdir(tmpdir / "version_0") + + +def test_tensorboard_log_sub_dir(tmpdir): + # no sub_dir specified + root_dir = tmpdir / "logs" + logger = TensorBoardLogger(root_dir, name="name", version="version") + assert logger.log_dir == os.path.join(root_dir, "name", "version") + + # sub_dir specified + logger = TensorBoardLogger(root_dir, name="name", version="version", sub_dir="sub_dir") + assert logger.log_dir == os.path.join(root_dir, "name", "version", "sub_dir") + + +def test_tensorboard_expand_home(): + """Test that the home dir (`~`) gets expanded properly.""" + root_dir = "~/tmp" + explicit_root_dir = os.path.expanduser(root_dir) + logger = TensorBoardLogger(root_dir, name="name", version="version", sub_dir="sub_dir") + assert logger.root_dir == root_dir + assert logger.log_dir == os.path.join(explicit_root_dir, "name", "version", "sub_dir") + + +@mock.patch.dict(os.environ, {"TEST_ENV_DIR": "some_directory"}) +def test_tensorboard_expand_env_vars(): + """Test that the env vars in path names (`$`) get handled properly.""" + test_env_dir = os.environ["TEST_ENV_DIR"] + root_dir = "$TEST_ENV_DIR/tmp" + explicit_root_dir = f"{test_env_dir}/tmp" + logger = TensorBoardLogger(root_dir, name="name", version="version", sub_dir="sub_dir") + assert logger.log_dir == os.path.join(explicit_root_dir, "name", "version", "sub_dir") + + +@pytest.mark.parametrize("step_idx", [10, None]) +def test_tensorboard_log_metrics(tmpdir, step_idx): + logger = TensorBoardLogger(tmpdir) + metrics = {"float": 0.3, "int": 1, "FloatTensor": torch.tensor(0.1), "IntTensor": torch.tensor(1)} + logger.log_metrics(metrics, step_idx) + + +def test_tensorboard_log_hyperparams(tmpdir): + logger = TensorBoardLogger(tmpdir) + hparams = { + "float": 0.3, + "int": 1, + "string": "abc", + "bool": True, + "dict": {"a": {"b": "c"}}, + "list": [1, 2, 3], + "namespace": Namespace(foo=Namespace(bar="buzz")), + "layer": torch.nn.BatchNorm1d, + "tensor": torch.empty(2, 2, 2), + "array": np.empty([2, 2, 2]), + } + logger.log_hyperparams(hparams) + + +def test_tensorboard_log_hparams_and_metrics(tmpdir): + logger = TensorBoardLogger(tmpdir, default_hp_metric=False) + hparams = { + "float": 0.3, + "int": 1, + "string": "abc", + "bool": True, + "dict": {"a": {"b": "c"}}, + "list": [1, 2, 3], + "namespace": Namespace(foo=Namespace(bar="buzz")), + "layer": torch.nn.BatchNorm1d, + "tensor": torch.empty(2, 2, 2), + "array": np.empty([2, 2, 2]), + } + metrics = {"abc": torch.tensor([0.54])} + logger.log_hyperparams(hparams, metrics) + + +@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)]) +def test_tensorboard_log_graph(tmpdir, example_input_array): + """test that log graph works with both 
model.example_input_array and if array is passed externally.""" + # TODO(fabric): Test both nn.Module and LightningModule + # TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks + model = BoringModel() + if example_input_array is not None: + model.example_input_array = None + + logger = TensorBoardLogger(tmpdir, log_graph=True) + logger.log_graph(model, example_input_array) + + +@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE)) +def test_tensorboard_log_graph_warning_no_example_input_array(tmpdir): + """test that log graph throws warning if model.example_input_array is None.""" + model = BoringModel() + model.example_input_array = None + logger = TensorBoardLogger(tmpdir, log_graph=True) + with pytest.warns( + UserWarning, + match="Could not log computational graph to TensorBoard: The `model.example_input_array` .* was not given", + ): + logger.log_graph(model) + + model.example_input_array = dict(x=1, y=2) + with pytest.warns( + UserWarning, match="Could not log computational graph to TensorBoard: .* can't be traced by TensorBoard" + ): + logger.log_graph(model) + + +def test_tensorboard_finalize(monkeypatch, tmpdir): + """Test that the SummaryWriter closes in finalize.""" + if _TENSORBOARD_AVAILABLE: + import torch.utils.tensorboard as tb + else: + import tensorboardX as tb + + monkeypatch.setattr(tb, "SummaryWriter", Mock()) + logger = TensorBoardLogger(root_dir=tmpdir) + assert logger._experiment is None + logger.finalize("any") + + # no log calls, no experiment created -> nothing to flush + logger.experiment.assert_not_called() + + logger = TensorBoardLogger(root_dir=tmpdir) + logger.log_metrics({"flush_me": 11.1}) # trigger creation of an experiment + logger.finalize("any") + + # finalize flushes to experiment directory + logger.experiment.flush.assert_called() + logger.experiment.close.assert_called() + + +@mock.patch("lightning_fabric.loggers.tensorboard.log") +def test_tensorboard_with_symlink(log, tmpdir): + """Tests a specific failure case when tensorboard logger is used with empty name, symbolic link ``save_dir``, + and relative paths.""" + os.chdir(tmpdir) # need to use relative paths + source = os.path.join(".", "lightning_logs") + dest = os.path.join(".", "sym_lightning_logs") + + os.makedirs(source, exist_ok=True) + os.symlink(source, dest) + + logger = TensorBoardLogger(root_dir=dest, name="") + _ = logger.version + + log.warning.assert_not_called() + + +def test_tensorboard_missing_folder_warning(tmpdir, caplog): + """Verify that the logger throws a warning for invalid directory.""" + + name = "fake_dir" + logger = TensorBoardLogger(root_dir=tmpdir, name=name) + + with caplog.at_level(logging.WARNING): + assert logger.version == 0 + + assert "Missing logger folder:" in caplog.text diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index 8aa40b0359a3f..88ac3d9fb7d82 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -718,7 +718,7 @@ def test_callbacks_input(): assert fabric._callbacks == [callback0, callback1] -def test_call_callbacks(): +def test_call(): """Test that `fabric.call` triggers the callback implementations.""" callback0 = Mock() callback1 = Mock() @@ -748,3 +748,55 @@ def test_call_callbacks(): with pytest.warns(UserWarning, match="Skipping the callback `Mock.not_a_method`"): fabric.call("not_a_method") assert not callback1.mock_calls + + +def test_loggers_input(): + """Test the various ways in which loggers can be 
registered with Fabric.""" + logger0 = Mock() + logger1 = Mock() + + # no logger + fabric = Fabric(loggers=None) + assert fabric._loggers == [] + fabric = Fabric(loggers=[]) + assert fabric._loggers == [] + + # single logger + fabric = Fabric(loggers=logger0) + assert fabric._loggers == [logger0] + + # multiple loggers + fabric = Fabric(loggers=[logger0, logger1]) + assert fabric._loggers == [logger0, logger1] + + +def test_log(): + """Test that `fabric.log` sends the metrics to each logger.""" + + logger0 = Mock() + logger1 = Mock() + fabric = Fabric(loggers=[logger0, logger1]) + + fabric.log("test", 1) + logger0.log_metrics.assert_called_with(metrics={"test": 1}, step=None) + logger1.log_metrics.assert_called_with(metrics={"test": 1}, step=None) + + fabric.log("test", 2, step=15) + logger0.log_metrics.assert_called_with(metrics={"test": 2}, step=15) + logger1.log_metrics.assert_called_with(metrics={"test": 2}, step=15) + + +def test_log_dict(): + """Test that `fabric.log_dict` sends the metrics dict to each logger.""" + + logger0 = Mock() + logger1 = Mock() + fabric = Fabric(loggers=[logger0, logger1]) + + fabric.log_dict({"foo": 1, "bar": 2}, step=None) + logger0.log_metrics.assert_called_with(metrics={"foo": 1, "bar": 2}, step=None) + logger1.log_metrics.assert_called_with(metrics={"foo": 1, "bar": 2}, step=None) + + fabric.log_dict({"foo": 3, "bar": 4}, step=15) + logger0.log_metrics.assert_called_with(metrics={"foo": 3, "bar": 4}, step=15) + logger1.log_metrics.assert_called_with(metrics={"foo": 3, "bar": 4}, step=15)
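
For context, here is a minimal end-to-end sketch of how the pieces introduced in this change fit together. The model, optimizer, data, step count, and directory names below are illustrative placeholders and not part of the diff; only ``Fabric(loggers=...)``, ``fabric.log``, ``fabric.log_dict``, and ``fabric.logger`` correspond to the API added here.

.. code-block:: python

    import torch

    from lightning_fabric import Fabric
    from lightning_fabric.loggers import TensorBoardLogger

    # One TensorBoard logger; passing a list of loggers works the same way.
    fabric = Fabric(loggers=TensorBoardLogger(root_dir="logs", name="demo"))

    model = torch.nn.Linear(32, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    model, optimizer = fabric.setup(model, optimizer)

    for step in range(10):
        batch = torch.randn(4, 32, device=fabric.device)
        loss = model(batch).sum()
        fabric.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

        # Both calls fan out to every configured logger via logger.log_metrics(...)
        fabric.log("loss", loss.item(), step=step)
        fabric.log_dict({"loss": loss.item(), "lr": 0.1}, step=step)

    # The first (main) logger is exposed as fabric.logger
    fabric.logger.finalize("success")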