Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/lightning_lite/accelerators/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch

from lightning_lite.accelerators.accelerator import Accelerator
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_14
from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCH_GREATER_EQUAL_1_14


class CUDAAccelerator(Accelerator):
Expand Down Expand Up @@ -78,11 +78,12 @@ def _get_all_available_cuda_gpus() -> List[int]:
return list(range(num_cuda_devices()))


# TODO: Remove once minimum supported PyTorch version is 1.14
@contextmanager
def _patch_cuda_is_available() -> Generator:
"""Context manager that safely patches :func:`torch.cuda.is_available` with its NVML-based version if
possible."""
if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0:
if hasattr(torch._C, "_cuda_getDeviceCount") and _device_count_nvml() >= 0 and not _TORCH_GREATER_EQUAL_1_14:
# we can safely patch is_available only when torch was compiled with CUDA and the NVML device count succeeds;
# otherwise, patching is_available could lead to attribute errors or infinite recursion
orig_check = torch.cuda.is_available
Expand All @@ -102,12 +103,11 @@ def num_cuda_devices() -> int:
Unlike :func:`torch.cuda.device_count`, this function does its best not to create a CUDA context for fork support,
if the platform allows it.
"""
if _TORCH_GREATER_EQUAL_1_14:
# We set `PYTORCH_NVML_BASED_CUDA_CHECK=1` in lightning_lite.__init__.py
if _TORCH_GREATER_EQUAL_1_13:
return torch.cuda.device_count()

# Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879
# TODO: Remove once minimum supported PyTorch version is 1.14
# TODO: Remove once minimum supported PyTorch version is 1.13
nvml_count = _device_count_nvml()
return torch.cuda.device_count() if nvml_count < 0 else nvml_count

Expand All @@ -118,9 +118,11 @@ def is_cuda_available() -> bool:
Unlike :func:`torch.cuda.is_available`, this function does its best not to create a CUDA context for fork support,
if the platform allows it.
"""
return num_cuda_devices() > 0
# We set `PYTORCH_NVML_BASED_CUDA_CHECK=1` in lightning_lite.__init__.py
return torch.cuda.is_available() if _TORCH_GREATER_EQUAL_1_14 else num_cuda_devices() > 0


# TODO: Remove once minimum supported PyTorch version is 1.13
def _parse_visible_devices() -> Set[int]:
"""Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
var = os.getenv("CUDA_VISIBLE_DEVICES")
Expand All @@ -146,6 +148,7 @@ def _strtoul(s: str) -> int:
return rc


# TODO: Remove once minimum supported PyTorch version is 1.13
def _raw_device_count_nvml() -> int:
"""Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
from ctypes import c_int, CDLL
Expand All @@ -164,6 +167,7 @@ def _raw_device_count_nvml() -> int:
return dev_arr[0]


# TODO: Remove once minimum supported PyTorch version is 1.13
def _device_count_nvml() -> int:
"""Implementation copied from upstream: https://github.com/pytorch/pytorch/pull/84879."""
try:
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `Callback.on_load_checkpoint` now gets the full checkpoint dictionary and the `callback_state` argument was renamed `checkpoint` ([#14835](https://github.com/Lightning-AI/lightning/pull/14835))


- From PyTorch 1.14 and higher, Lightning will configure PyTorch to use an NVML-based check for `torch.cuda.is_available` and `torch.cuda.device_count` to avoid issues with forking processes ([#15110](https://github.com/Lightning-AI/lightning/pull/15110))
- To avoid issues with forking processes, Lightning will directly use the PyTorch NVML-based check for `torch.cuda.device_count` from PyTorch 1.13 and higher, and will configure PyTorch to use an NVML-based check for `torch.cuda.is_available` from PyTorch 1.14 and higher ([#15110](https://github.com/Lightning-AI/lightning/pull/15110), [#15133](https://github.com/Lightning-AI/lightning/pull/15133))


### Deprecated
Expand Down