Merged: changes from all 26 commits
27 changes: 21 additions & 6 deletions docs/source-pytorch/accelerators/gpu_basic.rst
@@ -88,10 +88,25 @@ The table below lists examples of possible input formats and how they are interpreted
| "-1" | str | [0, 1, 2, ...] | all available GPUs |
+------------------+-----------+---------------------+---------------------------------+

.. note::

    When specifying the number of ``devices`` as an integer ``devices=k``, setting the trainer flag
``auto_select_gpus=True`` will automatically help you find ``k`` GPUs that are not
occupied by other processes. This is especially useful when GPUs are configured
to be in "exclusive mode", such that only one process at a time can access them.
For more details see the :doc:`trainer guide <../common/trainer>`.
Find usable CUDA devices
^^^^^^^^^^^^^^^^^^^^^^^^

If you want to run several experiments at the same time on your machine, for example for a hyperparameter sweep, then you can
use the following utility function to pick GPU indices that are "accessible", without having to change your code every time.

.. code-block:: python

from lightning.pytorch.accelerators import find_usable_cuda_devices

# Find two GPUs on the system that are not already occupied
trainer = Trainer(accelerator="cuda", devices=find_usable_cuda_devices(2))

from lightning.lite.accelerators import find_usable_cuda_devices

# Works with LightningLite too
lite = LightningLite(accelerator="cuda", devices=find_usable_cuda_devices(2))


This is especially useful when GPUs are configured to be in "exclusive compute mode", such that only one process at a time is allowed access to the device.
This special mode is often enabled on server GPUs or systems shared among multiple users.
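
If you simply want all GPUs that are currently free, a minimal sketch (relying on the default ``num_devices=-1``, which returns every usable device found):

.. code-block:: python

    from lightning.pytorch.accelerators import find_usable_cuda_devices

    # claim every GPU on the machine that is not occupied by another process
    trainer = Trainer(accelerator="cuda", devices=find_usable_cuda_devices())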
28 changes: 0 additions & 28 deletions docs/source-pytorch/common/trainer.rst
@@ -314,34 +314,6 @@ before any training.
# call tune to find the batch size
trainer.tune(model)

auto_select_gpus
^^^^^^^^^^^^^^^^

.. raw:: html

<video width="50%" max-width="400px" controls
poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/auto_select+_gpus.jpg"
src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/auto_select_gpus.mp4"></video>

|

If enabled and ``devices`` is an integer, pick available GPUs automatically.
This is especially useful when GPUs are configured to be in "exclusive mode",
such that only one process at a time can access them.

Example::

    # no auto selection (picks first 2 GPUs on system, may fail if another process is occupying them)
    trainer = Trainer(accelerator="gpu", devices=2, auto_select_gpus=False)

    # enable auto selection (will find two available GPUs on system)
    trainer = Trainer(accelerator="gpu", devices=2, auto_select_gpus=True)

    # specifies all GPUs regardless of their availability
    Trainer(accelerator="gpu", devices=-1, auto_select_gpus=False)

    # specifies all available GPUs (if only one GPU is not occupied, uses that one)
    Trainer(accelerator="gpu", devices=-1, auto_select_gpus=True)

auto_lr_find
^^^^^^^^^^^^
3 changes: 3 additions & 0 deletions src/lightning_lite/CHANGELOG.md
@@ -23,6 +23,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for Fully Sharded Data Parallel (FSDP) training in Lightning Lite ([#14967](https://github.com/Lightning-AI/lightning/issues/14967))


- Added `lightning_lite.accelerators.find_usable_cuda_devices` utility function ([#16147](https://github.com/PyTorchLightning/pytorch-lightning/pull/16147))


### Changed

- The `LightningLite.run()` method is no longer abstract ([#14992](https://github.com/Lightning-AI/lightning/issues/14992))
1 change: 1 addition & 0 deletions src/lightning_lite/accelerators/__init__.py
@@ -13,6 +13,7 @@
from lightning_lite.accelerators.accelerator import Accelerator # noqa: F401
from lightning_lite.accelerators.cpu import CPUAccelerator # noqa: F401
from lightning_lite.accelerators.cuda import CUDAAccelerator # noqa: F401
from lightning_lite.accelerators.cuda import find_usable_cuda_devices # noqa: F401
from lightning_lite.accelerators.mps import MPSAccelerator # noqa: F401
from lightning_lite.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
from lightning_lite.accelerators.tpu import TPUAccelerator # noqa: F401
58 changes: 55 additions & 3 deletions src/lightning_lite/accelerators/cuda.py
@@ -78,10 +78,62 @@ def register_accelerators(cls, accelerator_registry: Dict) -> None:
)


def _get_all_available_cuda_gpus() -> List[int]:
def find_usable_cuda_devices(num_devices: int = -1) -> List[int]:
"""Returns a list of all available and usable CUDA GPU devices.

A GPU is considered usable if we can successfully move a tensor to the device, and this is what this function
tests for each GPU on the system until the target number of usable devices is found.

A subset of GPUs on the system might be used by other processes, and if the GPU is configured to operate in
'exclusive' mode (configurable by the admin), then only one process is allowed to occupy it.

Args:
num_devices: The number of devices you want to request. By default, this function will return as many as there
are usable CUDA GPU devices available.

Warning:
        If multiple processes call this function at the same time, there can be race conditions in the case where
        both processes determine that the device is unoccupied, leading to one of them crashing later on.
"""
Returns:
A list of all available CUDA GPUs
visible_devices = _get_all_visible_cuda_devices()
if not visible_devices:
raise ValueError(
f"You requested to find {num_devices} devices but there are no visible CUDA devices on this machine."
)
if num_devices > len(visible_devices):
raise ValueError(
f"You requested to find {num_devices} devices but this machine only has {len(visible_devices)} GPUs."
)

available_devices = []
unavailable_devices = []

for gpu_idx in visible_devices:
try:
torch.tensor(0, device=torch.device("cuda", gpu_idx))
except RuntimeError:
unavailable_devices.append(gpu_idx)
continue

available_devices.append(gpu_idx)
if len(available_devices) == num_devices:
# exit early if we found the right number of GPUs
break

    if num_devices != -1 and len(available_devices) != num_devices:
raise RuntimeError(
f"You requested to find {num_devices} devices but only {len(available_devices)} are currently available."
f" The devices {unavailable_devices} are occupied by other processes and can't be used at the moment."
)
return available_devices


def _get_all_visible_cuda_devices() -> List[int]:
"""Returns a list of all visible CUDA GPU devices.

    Devices masked by the environment variable ``CUDA_VISIBLE_DEVICES`` won't be returned here. For example, assume you
    have 8 physical GPUs. If ``CUDA_VISIBLE_DEVICES="1,3,6"``, then this function will return the list ``[0, 1, 2]``
    because these are the three visible GPUs after applying the mask ``CUDA_VISIBLE_DEVICES``.
"""
return list(range(num_cuda_devices()))

2 changes: 1 addition & 1 deletion src/lightning_lite/utilities/device_parser.py
@@ -160,7 +160,7 @@ def _get_all_available_gpus(include_cuda: bool = False, include_mps: bool = False)
Returns:
A list of all available GPUs
"""
cuda_gpus = accelerators.cuda._get_all_available_cuda_gpus() if include_cuda else []
cuda_gpus = accelerators.cuda._get_all_visible_cuda_devices() if include_cuda else []
mps_gpus = accelerators.mps._get_all_available_mps_gpus() if include_mps else []
return cuda_gpus + mps_gpus

6 changes: 6 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -76,6 +76,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `pytorch_lightning.profiler` in favor of `pytorch_lightning.profilers` ([#16059](https://github.com/PyTorchLightning/pytorch-lightning/pull/16059))


- Deprecated `Trainer(auto_select_gpus=...)` in favor of `pytorch_lightning.accelerators.find_usable_cuda_devices` ([#16147](https://github.com/PyTorchLightning/pytorch-lightning/pull/16147))


- Deprecated `pytorch_lightning.tuner.auto_gpu_select.{pick_single_gpu,pick_multiple_gpus}` in favor of `pytorch_lightning.accelerators.find_usable_cuda_devices` ([#16147](https://github.com/PyTorchLightning/pytorch-lightning/pull/16147))


- `nvidia/apex` deprecation ([#16039](https://github.com/PyTorchLightning/pytorch-lightning/pull/16039))
* Deprecated `pytorch_lightning.plugins.NativeMixedPrecisionPlugin` in favor of `pytorch_lightning.plugins.MixedPrecisionPlugin`
* Deprecated the `LightningModule.optimizer_step(using_native_amp=...)` argument
1 change: 1 addition & 0 deletions src/pytorch_lightning/accelerators/__init__.py
@@ -10,6 +10,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lightning_lite.accelerators.cuda import find_usable_cuda_devices # noqa: F401
from lightning_lite.accelerators.registry import _AcceleratorRegistry, call_register_accelerators
from pytorch_lightning.accelerators.accelerator import Accelerator # noqa: F401
from pytorch_lightning.accelerators.cpu import CPUAccelerator # noqa: F401
src/pytorch_lightning/trainer/connectors/accelerator_connector.py
@@ -107,7 +107,7 @@ def __init__(
benchmark: Optional[bool] = None,
replace_sampler_ddp: bool = True,
deterministic: Optional[Union[bool, _LITERAL_WARN]] = False,
auto_select_gpus: bool = False,
auto_select_gpus: Optional[bool] = None, # TODO: Remove in v1.10.0
num_processes: Optional[int] = None, # deprecated
tpu_cores: Optional[Union[List[int], str, int]] = None, # deprecated
ipus: Optional[int] = None, # deprecated
@@ -177,7 +177,7 @@ def __init__(
self.checkpoint_io: Optional[CheckpointIO] = None
self._amp_type_flag: Optional[str] = None # TODO: Remove in v1.10.0
self._amp_level_flag: Optional[str] = amp_level # TODO: Remove in v1.10.0
self._auto_select_gpus: bool = auto_select_gpus
self._auto_select_gpus: Optional[bool] = auto_select_gpus

self._check_config_and_set_final_flags(
strategy=strategy,
@@ -558,8 +558,17 @@ def _set_devices_flag_if_auto_passed(self) -> None:
self._devices_flag = self.accelerator.auto_device_count()

def _set_devices_flag_if_auto_select_gpus_passed(self) -> None:
if self._auto_select_gpus is not None:
rank_zero_deprecation(
"The Trainer argument `auto_select_gpus` has been deprecated in v1.9.0 and will be removed in v1.10.0."
" Please use the function `pytorch_lightning.accelerators.find_usable_cuda_devices` instead."
)
if self._auto_select_gpus and isinstance(self._gpus, int) and isinstance(self.accelerator, CUDAAccelerator):
self._devices_flag = pick_multiple_gpus(self._gpus)
self._devices_flag = pick_multiple_gpus(
self._gpus,
# we already show a deprecation message when user sets Trainer(auto_select_gpus=...)
_show_deprecation=False,
)
log.info(f"Auto select gpus: {self._devices_flag}")

def _choose_and_init_cluster_environment(self) -> ClusterEnvironment:
6 changes: 5 additions & 1 deletion src/pytorch_lightning/trainer/trainer.py
@@ -123,7 +123,7 @@ def __init__(
num_processes: Optional[int] = None, # TODO: Remove in 2.0
devices: Optional[Union[List[int], str, int]] = None,
gpus: Optional[Union[List[int], str, int]] = None, # TODO: Remove in 2.0
auto_select_gpus: bool = False,
auto_select_gpus: Optional[bool] = None, # TODO: Remove in 2.0
tpu_cores: Optional[Union[List[int], str, int]] = None, # TODO: Remove in 2.0
ipus: Optional[int] = None, # TODO: Remove in 2.0
enable_progress_bar: bool = True,
@@ -210,6 +210,10 @@
that only one process at a time can access them.
Default: ``False``.

.. deprecated:: v1.9
``auto_select_gpus`` has been deprecated in v1.9.0 and will be removed in v1.10.0.
Please use the function :func:`~lightning_lite.accelerators.cuda.find_usable_cuda_devices` instead.
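
                A sketch of the replacement, assuming you want two free GPUs::

                    from pytorch_lightning.accelerators import find_usable_cuda_devices

                    trainer = Trainer(accelerator="cuda", devices=find_usable_cuda_devices(2))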

benchmark: The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
(``False`` if not manually set). If :paramref:`~pytorch_lightning.trainer.Trainer.deterministic` is set
33 changes: 28 additions & 5 deletions src/pytorch_lightning/tuner/auto_gpu_select.py
@@ -17,15 +17,27 @@

from lightning_lite.accelerators.cuda import num_cuda_devices
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


def pick_multiple_gpus(nb: int) -> List[int]:
"""
def pick_multiple_gpus(nb: int, _show_deprecation: bool = True) -> List[int]:
"""Pick a number of GPUs that are not yet in use.

.. deprecated:: v1.9.0
The function ``pick_multiple_gpus`` has been deprecated in v1.9.0 and will be removed in v1.10.0.
Please use the function ``pytorch_lightning.accelerators.find_usable_cuda_devices`` instead.

Raises:
MisconfigurationException:
            If ``gpus`` or ``devices`` is set to 0 when ``auto_select_gpus=True``, or when the requested number of
            GPUs is higher than the number available on the machine.
"""
if _show_deprecation:
rank_zero_deprecation(
"The function `pick_multiple_gpus` has been deprecated in v1.9.0 and will be removed in v1.10.0."
" Please use the function `pytorch_lightning.accelerators.find_usable_cuda_devices` instead."
)

if nb == 0:
raise MisconfigurationException(
"auto_select_gpus=True, gpus=0 is not a valid configuration."
@@ -39,17 +51,28 @@ def pick_multiple_gpus(nb: int) -> List[int]:

picked: List[int] = []
for _ in range(nb):
picked.append(pick_single_gpu(exclude_gpus=picked))
picked.append(pick_single_gpu(exclude_gpus=picked, _show_deprecation=False))

return picked


def pick_single_gpu(exclude_gpus: List[int]) -> int:
"""
def pick_single_gpu(exclude_gpus: List[int], _show_deprecation: bool = True) -> int:
"""Find a GPU that is not yet in use.

.. deprecated:: v1.9.0
The function ``pick_single_gpu`` has been deprecated in v1.9.0 and will be removed in v1.10.0.
Please use the function ``pytorch_lightning.accelerators.find_usable_cuda_devices`` instead.

Raises:
RuntimeError:
            If you try to allocate a GPU when none are available.
"""
if _show_deprecation:
rank_zero_deprecation(
"The function `pick_single_gpu` has been deprecated in v1.9.0 and will be removed in v1.10.0."
" Please use the function `pytorch_lightning.accelerators.find_usable_cuda_devices` instead."
)

previously_used_gpus = []
unused_gpus = []
for i in range(num_cuda_devices()):
25 changes: 25 additions & 0 deletions tests/tests_lite/accelerators/test_cuda.py
@@ -14,6 +14,7 @@
import importlib
import logging
import os
from re import escape
from unittest import mock
from unittest.mock import Mock

@@ -25,6 +26,7 @@
from lightning_lite.accelerators.cuda import (
_check_cuda_matmul_precision,
CUDAAccelerator,
find_usable_cuda_devices,
is_cuda_available,
num_cuda_devices,
)
@@ -114,3 +116,26 @@ def test_tf32_message(_, __, caplog):
with caplog.at_level(logging.INFO):
_check_cuda_matmul_precision(device)
assert expected in caplog.text


def test_find_usable_cuda_devices_error_handling():
"""Test error handling for edge cases when using `find_usable_cuda_devices`."""

# Asking for GPUs if no GPUs visible
with mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=0), pytest.raises(
ValueError, match="You requested to find 2 devices but there are no visible CUDA"
):
find_usable_cuda_devices(2)

# Asking for more GPUs than are visible
with mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=1), pytest.raises(
ValueError, match="this machine only has 1 GPUs"
):
find_usable_cuda_devices(2)

# All GPUs are unusable
tensor_mock = Mock(side_effect=RuntimeError) # simulate device placement fails
with mock.patch("lightning_lite.accelerators.cuda.num_cuda_devices", return_value=2), mock.patch(
"lightning_lite.accelerators.cuda.torch.tensor", tensor_mock
), pytest.raises(RuntimeError, match=escape("The devices [0, 1] are occupied by other processes")):
find_usable_cuda_devices(2)
22 changes: 22 additions & 0 deletions tests/tests_pytorch/deprecated_api/test_remove_1-10.py
@@ -24,6 +24,7 @@
import pytorch_lightning.profiler as profiler
from lightning_lite.accelerators import CUDAAccelerator as LiteCUDAAccelerator
from lightning_lite.accelerators import TPUAccelerator as LiteTPUAccelerator
from lightning_lite.utilities.exceptions import MisconfigurationException
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators.cpu import CPUAccelerator
from pytorch_lightning.cli import LightningCLI
@@ -38,6 +39,7 @@
from pytorch_lightning.strategies.bagua import LightningBaguaModule
from pytorch_lightning.strategies.utils import on_colab_kaggle
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus, pick_single_gpu
from pytorch_lightning.utilities.apply_func import (
apply_to_collection,
apply_to_collections,
@@ -409,3 +411,23 @@ def optimizer_step(
def test_horovod_deprecation_warnings(*_):
with pytest.deprecated_call(match=r"horovod'\)` has been deprecated in v1.9"):
Trainer(strategy="horovod")


def test_auto_select_gpus():
with pytest.deprecated_call(match="The Trainer argument `auto_select_gpus` has been deprecated in v1.9.0"):
Trainer(auto_select_gpus=False)


def test_pick_multiple_gpus():
with pytest.deprecated_call(match="The function `pick_multiple_gpus` has been deprecated in v1.9.0"), pytest.raises(
MisconfigurationException
):
pick_multiple_gpus(0)


@mock.patch("pytorch_lightning.tuner.auto_gpu_select.num_cuda_devices", return_value=0)
def test_pick_single_gpu(_):
with pytest.deprecated_call(match="The function `pick_single_gpu` has been deprecated in v1.9.0"), pytest.raises(
RuntimeError
):
pick_single_gpu([])