Skip to content
24 changes: 23 additions & 1 deletion src/lightning_lite/strategies/launchers/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from lightning_lite.strategies.launchers.base import _Launcher
from lightning_lite.strategies.strategy import Strategy
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_11
from lightning_lite.utilities.imports import _IS_INTERACTIVE, _TORCH_GREATER_EQUAL_1_11
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states


Expand Down Expand Up @@ -86,6 +86,9 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
*args: Optional positional arguments to be passed to the given function.
**kwargs: Optional keyword arguments to be passed to the given function.
"""
if self._start_method in ("fork", "forkserver"):
_check_bad_cuda_fork()

# The default cluster environment in Lightning chooses a random free port number
# This needs to be done in the main process here before starting processes to ensure each rank will connect
# through the same port
Expand Down Expand Up @@ -176,3 +179,22 @@ def restore(self) -> None:
def _is_forking_disabled() -> bool:
"""Returns whether forking is disabled through the environment variable ``PL_DISABLE_FORK``."""
return bool(int(os.environ.get("PL_DISABLE_FORK", "0")))


def _check_bad_cuda_fork() -> None:
"""Checks whether it is safe to fork and initialize CUDA in the new processes, and raises an exception if not.

The error message replaces PyTorch's 'Cannot re-initialize CUDA in forked subprocess' with helpful advice for
Lightning users.
"""
if not torch.cuda.is_initialized():
return

message = (
"Lightning can't create new processes if CUDA is already initialized. Did you manually call"
" `torch.cuda.*` functions, have moved the model to the device or allocated memory on the GPU any"
" other way? Please remove any such calls, or change the selected strategy."
)
if _IS_INTERACTIVE:
message += " You will have to restart the Python session; in a notebook, that means restart the kernel."
raise RuntimeError(message)
4 changes: 4 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `WandbLogger.download_artifact` and `WandbLogger.use_artifact` for managing artifacts with Weights and Biases ([#14551](https://github.com/Lightning-AI/lightning/issues/14551))


- Added a friendlier error message when attempting to fork processes with pre-initialized CUDA context ([#14709](https://github.com/Lightning-AI/lightning/issues/14709))



### Changed

- The `Trainer.{fit,validate,test,predict,tune}` methods now raise a useful error message if the input is not a `LightningModule` ([#13892](https://github.com/Lightning-AI/lightning/pull/13892))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import pytorch_lightning as pl
from lightning_lite.strategies.launchers.base import _Launcher
from lightning_lite.strategies.launchers.multiprocessing import _check_bad_cuda_fork
from lightning_lite.utilities.apply_func import move_data_to_device
from lightning_lite.utilities.seed import _collect_rng_states, _set_rng_states
from lightning_lite.utilities.types import _PATH
Expand Down Expand Up @@ -94,6 +95,9 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
**kwargs: Optional keyword arguments to be passed to the given function.
"""
self._check_torchdistx_support()
if self._start_method in ("fork", "forkserver"):
_check_bad_cuda_fork()

# The default cluster environment in Lightning chooses a random free port number
# This needs to be done in the main process here before starting processes to ensure each rank will connect
# through the same port
Expand Down
10 changes: 10 additions & 0 deletions tests/tests_lite/strategies/launchers/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,13 @@ def test_global_state_snapshot():
assert torch.are_deterministic_algorithms_enabled()
assert not torch.backends.cudnn.benchmark
assert torch.initial_seed() == 123


@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
@mock.patch("torch.cuda.is_initialized", return_value=True)
@mock.patch("lightning_lite.strategies.launchers.multiprocessing.mp")
def test_multiprocessing_launcher_check_for_bad_cuda_fork(mp_mock, _, start_method):
mp_mock.get_all_start_methods.return_value = [start_method]
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
with pytest.raises(RuntimeError, match="Lightning can't create new processes if CUDA is already initialized"):
launcher.launch(function=Mock())