Skip to content

Commit d46091c

Browse files
Bordacarmocca
andauthored
fix dirpath in log_dir for CSVLogger (#16401)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent cfe87a0 commit d46091c

File tree

4 files changed

+14
-13
lines changed

4 files changed

+14
-13
lines changed

.github/workflows/ci-tests-pytorch.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,10 @@ jobs:
218218
flags: ${COVERAGE_SCOPE},cpu,pytest-full,python${{ matrix.python-version }},pytorch${{ matrix.pytorch-version }}
219219
name: CPU-coverage
220220
fail_ci_if_error: false
221+
222+
# TODO
223+
# - name: Testing legacy creation
224+
# working-directory: tests/
225+
# run: |
226+
# export PYTHONPATH=$(dirname $LEGACY_PATH);$PYTHONPATH # for `import tests_pytorch`
227+
# python legacy/simple_classif_training.py

src/pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -582,14 +582,11 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH:
582582
return self.dirpath
583583

584584
if len(trainer.loggers) > 0:
585-
if trainer.loggers[0].save_dir is not None:
586-
save_dir = trainer.loggers[0].save_dir
587-
else:
588-
save_dir = trainer.default_root_dir
589-
name = trainer.loggers[0].name
590-
version = trainer.loggers[0].version
585+
logger_ = trainer.loggers[0]
586+
save_dir = getattr(logger_, "save_dir", None) or trainer.default_root_dir
587+
version = logger_.version
591588
version = version if isinstance(version, str) else f"version_{version}"
592-
ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
589+
ckpt_path = os.path.join(save_dir, str(logger_.name), version, "checkpoints")
593590
else:
594591
# if no loggers, use default_root_dir
595592
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")

src/pytorch_lightning/loggers/csv_logs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class ExperimentWriter(_FabricExperimentWriter):
3939
r"""
4040
Experiment writer for CSVLogger.
4141
42-
Currently supports to log hyperparameters and metrics in YAML and CSV
42+
Currently, supports to log hyperparameters and metrics in YAML and CSV
4343
format, respectively.
4444
4545
Args:

src/pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
5454
from pytorch_lightning.core.datamodule import LightningDataModule
5555
from pytorch_lightning.loggers import Logger
56-
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
5756
from pytorch_lightning.loops import PredictionLoop, TrainingEpochLoop
5857
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
5958
from pytorch_lightning.loops.fit_loop import FitLoop
@@ -1807,10 +1806,8 @@ def model(self, model: torch.nn.Module) -> None:
18071806
@property
18081807
def log_dir(self) -> Optional[str]:
18091808
if len(self.loggers) > 0:
1810-
if not isinstance(self.loggers[0], TensorBoardLogger):
1811-
dirpath = self.loggers[0].save_dir
1812-
else:
1813-
dirpath = self.loggers[0].log_dir
1809+
logger_ = self.loggers[0]
1810+
dirpath = getattr(logger_, "log_dir", None) or getattr(logger_, "save_dir", None)
18141811
else:
18151812
dirpath = self.default_root_dir
18161813

0 commit comments

Comments
 (0)