
Commit b97cf4b

Mark the connectors as protected (#17008)
1 parent bb960b8 commit b97cf4b

16 files changed: +67, -63 lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 0 deletions

```diff
@@ -103,6 +103,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Marked `lightning.pytorch.utilities.supporters.CombinedDataset` as protected ([#16714](https://github.com/Lightning-AI/lightning/pull/16714))
 
+- Marked the `{Accelerator,Signal,Callback,Checkpoint,Data,Logger}Connector` classes as protected ([#17008](https://github.com/Lightning-AI/lightning/pull/17008))
+
+- Marked the `lightning.pytorch.trainer.connectors.signal_connector.HandlersCompose` class as protected ([#17008](https://github.com/Lightning-AI/lightning/pull/17008))
+
 - Disabled strict loading in multiprocessing launcher ("ddp_spawn", etc.) when loading weights back into the main process ([#16365](https://github.com/Lightning-AI/lightning/pull/16365))
 
```
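In Python, "protected" is purely a naming convention: the leading underscore marks a class as internal API with no stability guarantee, but nothing prevents importing it. A minimal sketch of what the rename means in practice (assumes `lightning` is installed; `trainer._checkpoint_connector` is itself an internal attribute, as the trainer diff further below shows):

```python
from lightning.pytorch import Trainer

# Post-rename, the connector classes carry a leading underscore. They can
# still be imported, but the name now flags them as internal and subject
# to change without a deprecation cycle.
from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector

trainer = Trainer()
assert isinstance(trainer._checkpoint_connector, _CheckpointConnector)
```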

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -76,7 +76,7 @@
 _LITERAL_WARN = Literal["warn"]
 
 
-class AcceleratorConnector:
+class _AcceleratorConnector:
     def __init__(
         self,
         devices: Union[List[int], str, int] = "auto",
```

src/lightning/pytorch/trainer/connectors/callback_connector.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -40,7 +40,7 @@
 _log = logging.getLogger(__name__)
 
 
-class CallbackConnector:
+class _CallbackConnector:
     def __init__(self, trainer: "pl.Trainer"):
         self.trainer = trainer
 
@@ -181,7 +181,7 @@ def _attach_model_callbacks(self) -> None:
         # remove all callbacks with a type that occurs in model callbacks
         all_callbacks = [c for c in trainer.callbacks if type(c) not in override_types]
         all_callbacks.extend(model_callbacks)
-        all_callbacks = CallbackConnector._reorder_callbacks(all_callbacks)
+        all_callbacks = _CallbackConnector._reorder_callbacks(all_callbacks)
         # TODO: connectors refactor: move callbacks list to connector and do not write Trainer state
         trainer.callbacks = all_callbacks
```

src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

Lines changed: 4 additions & 4 deletions

```diff
@@ -42,7 +42,7 @@
 log: logging.Logger = logging.getLogger(__name__)
 
 
-class CheckpointConnector:
+class _CheckpointConnector:
     def __init__(self, trainer: "pl.Trainer") -> None:
         self.trainer = trainer
         self._ckpt_path: Optional[_PATH] = None
@@ -221,7 +221,7 @@ def resume_end(self) -> None:
             torch.cuda.empty_cache()
 
         # wait for all to catch up
-        self.trainer.strategy.barrier("CheckpointConnector.resume_end")
+        self.trainer.strategy.barrier("_CheckpointConnector.resume_end")
 
     def restore(self, checkpoint_path: Optional[_PATH] = None) -> None:
         """Attempt to restore everything at once from a 'PyTorch-Lightning checkpoint' file through file-read and
@@ -545,13 +545,13 @@ def __max_ckpt_version_in_folder(dir_path: _PATH, name_key: str = "ckpt_") -> Op
     def __get_max_ckpt_path_from_folder(folder_path: _PATH) -> str:
         """Get path of maximum-epoch checkpoint in the folder."""
 
-        max_suffix = CheckpointConnector.__max_ckpt_version_in_folder(folder_path)
+        max_suffix = _CheckpointConnector.__max_ckpt_version_in_folder(folder_path)
         ckpt_number = max_suffix if max_suffix is not None else 0
         return f"{folder_path}/hpc_ckpt_{ckpt_number}.ckpt"
 
     @staticmethod
     def hpc_save_path(folderpath: _PATH) -> str:
-        max_suffix = CheckpointConnector.__max_ckpt_version_in_folder(folderpath)
+        max_suffix = _CheckpointConnector.__max_ckpt_version_in_folder(folderpath)
         ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
         filepath = os.path.join(folderpath, f"hpc_ckpt_{ckpt_number}.ckpt")
         return filepath
```
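Note why these call sites need touching at all: `__max_ckpt_version_in_folder` is a double-underscore method, so it is only reachable through an explicit class reference, and that reference must track the class's new name. A minimal sketch of Python's name mangling with toy names (not Lightning code); incidentally, leading underscores on the class are stripped before mangling, so the stored attribute name `_CheckpointConnector__max_ckpt_version_in_folder` is the same before and after this rename:

```python
class _Connector:
    @staticmethod
    def __helper() -> int:
        # Double-underscore names are mangled inside the class body:
        # this method is actually stored as `_Connector__helper`.
        return 1

    @staticmethod
    def version() -> int:
        # The explicit class reference must match the class's current
        # name, which is why the diff rewrites these call sites.
        return _Connector.__helper()


print(_Connector.version())             # -> 1
print(_Connector._Connector__helper())  # -> 1, via the mangled name
```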

src/lightning/pytorch/trainer/connectors/data_connector.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -43,7 +43,7 @@
 warning_cache = WarningCache()
 
 
-class DataConnector:
+class _DataConnector:
     def __init__(self, trainer: "pl.Trainer"):
         self.trainer = trainer
         self._datahook_selector: Optional[_DataHookSelector] = None
```
src/lightning/pytorch/trainer/connectors/logger_connector/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1 +1 @@
-from lightning.pytorch.trainer.connectors.logger_connector.logger_connector import LoggerConnector  # noqa: F401
+from lightning.pytorch.trainer.connectors.logger_connector.logger_connector import _LoggerConnector  # noqa: F401
```
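This one-line file is the package re-export that lets `trainer.py` import `_LoggerConnector` from `...connectors.logger_connector` directly instead of reaching into the nested `logger_connector.logger_connector` module; the `# noqa: F401` marker tells flake8 the seemingly unused import is intentional. A toy sketch of the same pattern, using a hypothetical package `mypkg`:

```python
# mypkg/__init__.py (hypothetical package illustrating the re-export pattern)
# Re-export a private class from a submodule so callers can write
# `from mypkg import _Widget` instead of `from mypkg.widget import _Widget`.
from mypkg.widget import _Widget  # noqa: F401
```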

src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -28,7 +28,7 @@
 warning_cache = WarningCache()
 
 
-class LoggerConnector:
+class _LoggerConnector:
     def __init__(self, trainer: "pl.Trainer") -> None:
         self.trainer = trainer
         self._progress_bar_metrics: _PBAR_DICT = {}
```

src/lightning/pytorch/trainer/connectors/signal_connector.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -22,7 +22,7 @@
 log = logging.getLogger(__name__)
 
 
-class HandlersCompose:
+class _HandlersCompose:
     def __init__(self, signal_handlers: Union[List[_HANDLER], _HANDLER]) -> None:
         if not isinstance(signal_handlers, list):
             signal_handlers = [signal_handlers]
@@ -36,7 +36,7 @@ def __call__(self, signum: _SIGNUM, frame: FrameType) -> None:
             signal_handler(signum, frame)
 
 
-class SignalConnector:
+class _SignalConnector:
     def __init__(self, trainer: "pl.Trainer") -> None:
         self.received_sigterm = False
         self.trainer = trainer
@@ -60,12 +60,12 @@ def register_signal_handlers(self) -> None:
             sigusr = environment.requeue_signal if isinstance(environment, SLURMEnvironment) else signal.SIGUSR1
             assert sigusr is not None
             if sigusr_handlers and not self._has_already_handler(sigusr):
-                self._register_signal(sigusr, HandlersCompose(sigusr_handlers))
+                self._register_signal(sigusr, _HandlersCompose(sigusr_handlers))
 
         # we have our own handler, but include existing ones too
         if self._has_already_handler(signal.SIGTERM):
             sigterm_handlers.append(signal.getsignal(signal.SIGTERM))
-        self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))
+        self._register_signal(signal.SIGTERM, _HandlersCompose(sigterm_handlers))
 
     def _slurm_sigusr_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
         rank_zero_info(f"Handling auto-requeue signal: {signum}")
@@ -119,7 +119,7 @@ def _sigterm_handler_fn(self, signum: _SIGNUM, _: FrameType) -> None:
         log.info(f"Bypassing SIGTERM: {signum}")
 
     def teardown(self) -> None:
-        """Restores the signals that were previously configured before :class:`SignalConnector` replaced them."""
+        """Restores the signals that were previously configured before :class:`_SignalConnector` replaced them."""
         for signum, handler in self._original_handlers.items():
             if handler is not None:
                 self._register_signal(signum, handler)
@@ -128,7 +128,7 @@ def teardown(self) -> None:
     @staticmethod
     def _get_current_signal_handlers() -> Dict[_SIGNUM, _HANDLER]:
         """Collects the currently assigned signal handlers."""
-        valid_signals = SignalConnector._valid_signals()
+        valid_signals = _SignalConnector._valid_signals()
         if not _IS_WINDOWS:
             # SIGKILL and SIGSTOP are not allowed to be modified by the user
             valid_signals -= {signal.SIGKILL, signal.SIGSTOP}
```
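The hunk at line 60 shows the point of `_HandlersCompose`: the connector wants its own SIGUSR/SIGTERM handlers to run in addition to whatever was already installed, so it collects the existing handler via `signal.getsignal` and registers a single callable that invokes the whole list in order. A minimal standalone sketch of that composition pattern, using only the standard library (this is not Lightning's implementation; the real class also accepts a single handler and normalizes it to a list):

```python
import os
import signal
from types import FrameType
from typing import Callable, List, Optional

# A signal handler receives the signal number and the current stack frame.
_Handler = Callable[[int, Optional[FrameType]], None]


class HandlersCompose:
    """Invoke several handlers, in order, for a single signal."""

    def __init__(self, handlers: List[object]) -> None:
        self.handlers = handlers

    def __call__(self, signum: int, frame: Optional[FrameType]) -> None:
        for handler in self.handlers:
            # `signal.getsignal` can return SIG_DFL/SIG_IGN (not callable),
            # so guard before invoking.
            if callable(handler):
                handler(signum, frame)


def my_handler(signum: int, frame: Optional[FrameType]) -> None:
    print(f"received signal {signum}")


# Keep whatever was installed before and stack our handler on top of it.
previous = signal.getsignal(signal.SIGTERM)
signal.signal(signal.SIGTERM, HandlersCompose([previous, my_handler]))
os.kill(os.getpid(), signal.SIGTERM)  # prints "received signal 15" on POSIX
```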

src/lightning/pytorch/trainer/trainer.py

Lines changed: 12 additions & 12 deletions

```diff
@@ -51,17 +51,17 @@
 from lightning.pytorch.trainer import call, setup
 from lightning.pytorch.trainer.configuration_validator import verify_loop_configurations
 from lightning.pytorch.trainer.connectors.accelerator_connector import (
+    _AcceleratorConnector,
     _LITERAL_WARN,
     _PRECISION_INPUT,
     _PRECISION_INPUT_STR,
-    AcceleratorConnector,
 )
-from lightning.pytorch.trainer.connectors.callback_connector import CallbackConnector
-from lightning.pytorch.trainer.connectors.checkpoint_connector import CheckpointConnector
-from lightning.pytorch.trainer.connectors.data_connector import DataConnector
-from lightning.pytorch.trainer.connectors.logger_connector import LoggerConnector
+from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
+from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector
+from lightning.pytorch.trainer.connectors.data_connector import _DataConnector
+from lightning.pytorch.trainer.connectors.logger_connector import _LoggerConnector
 from lightning.pytorch.trainer.connectors.logger_connector.result import _OUT_DICT, _PBAR_DICT, _ResultCollection
-from lightning.pytorch.trainer.connectors.signal_connector import SignalConnector
+from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
 from lightning.pytorch.utilities import GradClipAlgorithmType, parsing
 from lightning.pytorch.utilities.argparse import _defaults_from_env_vars
@@ -388,9 +388,9 @@ def __init__(
             num_sanity_val_steps = 2
 
         # init connectors
-        self._data_connector = DataConnector(self)
+        self._data_connector = _DataConnector(self)
 
-        self._accelerator_connector = AcceleratorConnector(
+        self._accelerator_connector = _AcceleratorConnector(
             devices=devices,
             accelerator=accelerator,
             strategy=strategy,
@@ -402,10 +402,10 @@ def __init__(
             precision=precision,
             plugins=plugins,
         )
-        self._logger_connector = LoggerConnector(self)
-        self._callback_connector = CallbackConnector(self)
-        self._checkpoint_connector = CheckpointConnector(self)
-        self._signal_connector = SignalConnector(self)
+        self._logger_connector = _LoggerConnector(self)
+        self._callback_connector = _CallbackConnector(self)
+        self._checkpoint_connector = _CheckpointConnector(self)
+        self._signal_connector = _SignalConnector(self)
 
         # init loops
         self.fit_loop = _FitLoop(self, min_epochs=min_epochs, max_epochs=max_epochs)
```

tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -94,7 +94,7 @@ def test_trainer_save_checkpoint_storage_options(tmpdir, xla_available):
     io_mock.assert_called_with(ANY, instance_path, storage_options=None)
 
     with mock.patch(
-        "lightning.pytorch.trainer.connectors.checkpoint_connector.CheckpointConnector.save_checkpoint"
+        "lightning.pytorch.trainer.connectors.checkpoint_connector._CheckpointConnector.save_checkpoint"
     ) as cc_mock:
         trainer.save_checkpoint(instance_path, True)
     cc_mock.assert_called_with(instance_path, weights_only=True, storage_options=None)
```
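`mock.patch` resolves its target from a dotted string when the patch is entered, not when the test module is imported, so a class rename like this one breaks stale patch targets with an `AttributeError` at runtime; the test must spell the new private name. A small sketch of the failure mode (assumes `lightning` is installed):

```python
from unittest import mock

# Works after this commit: the string names the renamed, underscore-prefixed class.
with mock.patch(
    "lightning.pytorch.trainer.connectors.checkpoint_connector._CheckpointConnector.save_checkpoint"
):
    pass

# The pre-rename target now fails as soon as the patch is entered:
#   with mock.patch("lightning.pytorch.trainer.connectors.checkpoint_connector"
#                   ".CheckpointConnector.save_checkpoint"): ...
#   -> AttributeError: ... does not have the attribute 'CheckpointConnector'
```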
