|
34 | 34 | from torch import Tensor
|
35 | 35 |
|
36 | 36 | import lightning.pytorch as pl
|
37 |
| -from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem |
| 37 | +from lightning.fabric.utilities.cloud_io import _is_dir, _is_local_file_protocol, get_filesystem |
38 | 38 | from lightning.fabric.utilities.types import _PATH
|
39 | 39 | from lightning.pytorch.callbacks import Checkpoint
|
40 | 40 | from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
@@ -457,7 +457,7 @@ def __validate_init_configuration(self) -> None:
|
457 | 457 | def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
|
458 | 458 | self._fs = get_filesystem(dirpath if dirpath else "")
|
459 | 459 |
|
460 |
| - if dirpath and self._fs.protocol == "file": |
| 460 | + if dirpath and _is_local_file_protocol(dirpath if dirpath else ""): |
461 | 461 | dirpath = os.path.realpath(dirpath)
|
462 | 462 |
|
463 | 463 | self.dirpath = dirpath
|
@@ -675,7 +675,7 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
|
675 | 675 |
|
676 | 676 | # set the last model path before saving because it will be part of the state.
|
677 | 677 | previous, self.last_model_path = self.last_model_path, filepath
|
678 |
| - if self._fs.protocol == "file" and self._last_checkpoint_saved and self.save_top_k != 0: |
| 678 | + if _is_local_file_protocol(filepath) and self._last_checkpoint_saved and self.save_top_k != 0: |
679 | 679 | self._link_checkpoint(trainer, self._last_checkpoint_saved, filepath)
|
680 | 680 | else:
|
681 | 681 | self._save_checkpoint(trainer, filepath)
|
@@ -771,7 +771,7 @@ def _should_remove_checkpoint(self, trainer: "pl.Trainer", previous: str, curren
|
771 | 771 | """
|
772 | 772 | if previous == current:
|
773 | 773 | return False
|
774 |
| - if self._fs.protocol != "file": |
| 774 | + if not _is_local_file_protocol(previous): |
775 | 775 | return True
|
776 | 776 | previous = Path(previous).absolute()
|
777 | 777 | resume_path = Path(trainer.ckpt_path).absolute() if trainer.ckpt_path is not None else None
|
|
0 commit comments