Skip to content

Commit 2e90ddf

Browse files
committed
Fix logger typing by using fabric's base logger
1 parent 544aeae commit 2e90ddf

File tree

3 files changed

+12
-13
lines changed

3 files changed

+12
-13
lines changed

src/pytorch_lightning/core/module.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
import lightning_fabric as lf
3535
import pytorch_lightning as pl
36+
from lightning_fabric.loggers import Logger
3637
from lightning_fabric.utilities.apply_func import convert_to_tensors
3738
from lightning_fabric.utilities.cloud_io import get_filesystem
3839
from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
@@ -45,7 +46,6 @@
4546
from pytorch_lightning.core.mixins import HyperparametersMixin
4647
from pytorch_lightning.core.optimizer import LightningOptimizer
4748
from pytorch_lightning.core.saving import ModelIO
48-
from pytorch_lightning.loggers import Logger, TensorBoardLogger
4949
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
5050
from pytorch_lightning.utilities import GradClipAlgorithmType
5151
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -291,12 +291,12 @@ def truncated_bptt_steps(self, truncated_bptt_steps: int) -> None:
291291
self._truncated_bptt_steps = truncated_bptt_steps
292292

293293
@property
294-
def logger(self) -> Optional[Union[Logger, TensorBoardLogger]]:
294+
def logger(self) -> Optional[Logger]:
295295
"""Reference to the logger object in the Trainer."""
296296
return self._trainer.logger if self._trainer is not None else None
297297

298298
@property
299-
def loggers(self) -> List[Union[Logger, TensorBoardLogger]]:
299+
def loggers(self) -> List[Logger]:
300300
"""Reference to the list of loggers in the Trainer."""
301301
return self.trainer.loggers if self._trainer else []
302302

src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
from torch import Tensor
1818

1919
import pytorch_lightning as pl
20+
from lightning_fabric.loggers import Logger
2021
from lightning_fabric.plugins.environments import SLURMEnvironment
2122
from lightning_fabric.utilities import move_data_to_device
22-
from pytorch_lightning.loggers import Logger, TensorBoardLogger
23+
from pytorch_lightning.loggers import TensorBoardLogger
2324
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
2425
from pytorch_lightning.utilities.metrics import metrics_to_scalars
2526

@@ -37,7 +38,7 @@ def __init__(self, trainer: "pl.Trainer") -> None:
3738

3839
def on_trainer_init(
3940
self,
40-
logger: Union[bool, Logger, TensorBoardLogger, Iterable[Union[Logger, TensorBoardLogger]]],
41+
logger: Union[bool, Logger, Iterable[Logger]],
4142
log_every_n_steps: int,
4243
move_metrics_to_cpu: bool,
4344
) -> None:
@@ -51,9 +52,7 @@ def should_update_logs(self) -> bool:
5152
should_log = (self.trainer.fit_loop.epoch_loop._batches_that_stepped + 1) % self.trainer.log_every_n_steps == 0
5253
return should_log or self.trainer.should_stop
5354

54-
def configure_logger(
55-
self, logger: Union[bool, Logger, TensorBoardLogger, Iterable[Union[Logger, TensorBoardLogger]]]
56-
) -> None:
55+
def configure_logger(self, logger: Union[bool, Logger, Iterable[Logger]]) -> None:
5756
if not logger:
5857
# logger is None or logger is False
5958
self.trainer.loggers = []

src/pytorch_lightning/trainer/trainer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from typing_extensions import Literal
4444

4545
import pytorch_lightning as pl
46+
from lightning_fabric.loggers import Logger
4647
from lightning_fabric.utilities.cloud_io import get_filesystem
4748
from lightning_fabric.utilities.data import _auto_add_worker_init_fn
4849
from lightning_fabric.utilities.types import _PATH
@@ -51,7 +52,6 @@
5152
from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBarBase
5253
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
5354
from pytorch_lightning.core.datamodule import LightningDataModule
54-
from pytorch_lightning.loggers import Logger
5555
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
5656
from pytorch_lightning.loops import PredictionLoop, TrainingEpochLoop
5757
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
@@ -112,7 +112,7 @@ class Trainer:
112112
@_defaults_from_env_vars
113113
def __init__(
114114
self,
115-
logger: Union[Union[Logger, TensorBoardLogger], Iterable[Union[Logger, TensorBoardLogger]], bool] = True,
115+
logger: Union[Union[Logger], Iterable[Union[Logger]], bool] = True,
116116
enable_checkpointing: bool = True,
117117
callbacks: Optional[Union[List[Callback], Callback]] = None,
118118
default_root_dir: Optional[_PATH] = None,
@@ -448,7 +448,7 @@ def __init__(
448448
setup._init_profiler(self, profiler)
449449

450450
# init logger flags
451-
self._loggers: List[Union[Logger, TensorBoardLogger]]
451+
self._loggers: List[Logger]
452452
self._logger_connector.on_trainer_init(logger, log_every_n_steps, move_metrics_to_cpu)
453453

454454
# init debugging flags
@@ -2068,11 +2068,11 @@ def logger(self, logger: Optional[Logger]) -> None:
20682068
self.loggers = [logger]
20692069

20702070
@property
2071-
def loggers(self) -> List[Union[Logger, TensorBoardLogger]]:
2071+
def loggers(self) -> List[Union[Logger]]:
20722072
return self._loggers
20732073

20742074
@loggers.setter
2075-
def loggers(self, loggers: Optional[List[Union[Logger, TensorBoardLogger]]]) -> None:
2075+
def loggers(self, loggers: Optional[List[Logger]]) -> None:
20762076
self._loggers = loggers if loggers else []
20772077

20782078
@property

0 commit comments

Comments
 (0)