diff --git a/src/lightning/pytorch/accelerators/accelerator.py b/src/lightning/pytorch/accelerators/accelerator.py
index 9238071178a80..bed925b1c013c 100644
--- a/src/lightning/pytorch/accelerators/accelerator.py
+++ b/src/lightning/pytorch/accelerators/accelerator.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC
-from typing import Any
+from typing import Any, Optional
 
 import lightning.pytorch as pl
 from lightning.fabric.accelerators.accelerator import Accelerator as _Accelerator
@@ -45,3 +45,8 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
 
         """
         raise NotImplementedError
+
+    @classmethod
+    def device_name(cls, device: Optional[_DEVICE] = None) -> str:
+        """Get the device name for a given device."""
+        return str(cls.is_available())
diff --git a/src/lightning/pytorch/accelerators/cuda.py b/src/lightning/pytorch/accelerators/cuda.py
index a00b12a85a8dd..bfc581201e506 100644
--- a/src/lightning/pytorch/accelerators/cuda.py
+++ b/src/lightning/pytorch/accelerators/cuda.py
@@ -113,6 +113,13 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             description=cls.__name__,
         )
 
+    @classmethod
+    @override
+    def device_name(cls, device: Optional[_DEVICE] = None) -> str:
+        if not cls.is_available():
+            return ""
+        return torch.cuda.get_device_name(device)
+
 
 def get_nvidia_gpu_stats(device: _DEVICE) -> dict[str, float]:  # pragma: no-cover
     """Get GPU stats including memory, fan speed, and temperature from nvidia-smi.
diff --git a/src/lightning/pytorch/accelerators/mps.py b/src/lightning/pytorch/accelerators/mps.py
index f7674989cc721..7ad3c3304e84a 100644
--- a/src/lightning/pytorch/accelerators/mps.py
+++ b/src/lightning/pytorch/accelerators/mps.py
@@ -87,6 +87,14 @@ def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> No
             description=cls.__name__,
         )
 
+    @classmethod
+    @override
+    def device_name(cls, device: Optional[_DEVICE] = None) -> str:
+        # todo: implement a better way to get the device name
+        if not cls.is_available():
+            return ""
+        return "True (mps)"
+
 
 # device metrics
 _VM_PERCENT = "M1_vm_percent"
diff --git a/src/lightning/pytorch/accelerators/xla.py b/src/lightning/pytorch/accelerators/xla.py
index 10726b505448c..b248ab176d9f4 100644
--- a/src/lightning/pytorch/accelerators/xla.py
+++ b/src/lightning/pytorch/accelerators/xla.py
@@ -11,11 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any
+from typing import Any, Optional
 
 from typing_extensions import override
 
 from lightning.fabric.accelerators import _AcceleratorRegistry
+from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
 from lightning.fabric.accelerators.xla import XLAAccelerator as FabricXLAAccelerator
 from lightning.fabric.utilities.types import _DEVICE
 from lightning.pytorch.accelerators.accelerator import Accelerator
@@ -53,3 +54,25 @@ def get_device_stats(self, device: _DEVICE) -> dict[str, Any]:
     @override
     def register_accelerators(cls, accelerator_registry: _AcceleratorRegistry) -> None:
         accelerator_registry.register("tpu", cls, description=cls.__name__)
+
+    @classmethod
+    @override
+    def device_name(cls, device: Optional[_DEVICE] = None) -> str:
+        is_available = cls.is_available()
+        if not is_available:
+            return ""
+
+        if _XLA_GREATER_EQUAL_2_1:
+            from torch_xla._internal import tpu
+        else:
+            from torch_xla.experimental import tpu
+        import torch_xla.core.xla_env_vars as xenv
+        from requests.exceptions import HTTPError
+
+        try:
+            ret = tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]
+        except HTTPError:
+            # Fallback to "True" if HTTPError is raised during retrieving device information
+            ret = str(is_available)
+
+        return ret
diff --git a/src/lightning/pytorch/trainer/setup.py b/src/lightning/pytorch/trainer/setup.py
index 73591d30417b8..95cf30860ee2d 100644
--- a/src/lightning/pytorch/trainer/setup.py
+++ b/src/lightning/pytorch/trainer/setup.py
@@ -152,20 +152,23 @@ def _init_profiler(trainer: "pl.Trainer", profiler: Optional[Union[Profiler, str
 
 def _log_device_info(trainer: "pl.Trainer") -> None:
     if CUDAAccelerator.is_available():
-        gpu_available = True
-        gpu_type = " (cuda)"
+        if isinstance(trainer.accelerator, CUDAAccelerator):
+            device_name = ", ".join(list({CUDAAccelerator.device_name(d) for d in trainer.device_ids}))
+        else:
+            device_name = CUDAAccelerator.device_name()
     elif MPSAccelerator.is_available():
-        gpu_available = True
-        gpu_type = " (mps)"
+        device_name = MPSAccelerator.device_name()
     else:
-        gpu_available = False
-        gpu_type = ""
+        device_name = str(False)
 
-    gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator))
-    rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")
+    gpu_used = trainer.num_devices if isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator)) else 0
+    rank_zero_info(f"GPU available: {device_name}, using: {gpu_used} devices.")
 
     num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, XLAAccelerator) else 0
-    rank_zero_info(f"TPU available: {XLAAccelerator.is_available()}, using: {num_tpu_cores} TPU cores")
+    rank_zero_info(
+        f"TPU available: {XLAAccelerator.device_name() if XLAAccelerator.is_available() else str(False)}, "
+        f"using: {num_tpu_cores} TPU cores"
+    )
 
     if _habana_available_and_importable():
         from lightning_habana import HPUAccelerator
diff --git a/tests/tests_pytorch/accelerators/test_gpu.py b/tests/tests_pytorch/accelerators/test_gpu.py
index 5a71887e17eec..9312685847e1f 100644
--- a/tests/tests_pytorch/accelerators/test_gpu.py
+++ b/tests/tests_pytorch/accelerators/test_gpu.py
@@ -68,3 +68,16 @@ def test_gpu_availability():
 def test_warning_if_gpus_not_used(cuda_count_1):
     with pytest.warns(UserWarning, match="GPU available but not used"):
         Trainer(accelerator="cpu")
+
+
+@RunIf(min_cuda_gpus=1)
+def test_gpu_device_name():
+    for i in range(torch.cuda.device_count()):
+        assert torch.cuda.get_device_name(i) == CUDAAccelerator.device_name(i)
+
+    with torch.device("cuda:0"):
+        assert torch.cuda.get_device_name(0) == CUDAAccelerator.device_name()
+
+
+def test_gpu_device_name_no_gpu(cuda_count_0):
+    assert CUDAAccelerator.device_name() == ""
diff --git a/tests/tests_pytorch/accelerators/test_mps.py b/tests/tests_pytorch/accelerators/test_mps.py
index c0a28840f0ef6..c1c267b2bc1e9 100644
--- a/tests/tests_pytorch/accelerators/test_mps.py
+++ b/tests/tests_pytorch/accelerators/test_mps.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from collections import namedtuple
+from unittest import mock
 
 import pytest
 import torch
@@ -39,6 +40,16 @@ def test_mps_availability():
     assert MPSAccelerator.is_available()
 
 
+@RunIf(mps=True)
+def test_mps_device_name():
+    assert MPSAccelerator.device_name() == "True (mps)"
+
+
+def test_mps_device_name_not_available():
+    with mock.patch("torch.backends.mps.is_available", return_value=False):
+        assert MPSAccelerator.device_name() == ""
+
+
 def test_warning_if_mps_not_used(mps_count_1):
     with pytest.warns(UserWarning, match="GPU available but not used"):
         Trainer(accelerator="cpu")
diff --git a/tests/tests_pytorch/accelerators/test_xla.py b/tests/tests_pytorch/accelerators/test_xla.py
index 83dace719371d..e156a22a35c3c 100644
--- a/tests/tests_pytorch/accelerators/test_xla.py
+++ b/tests/tests_pytorch/accelerators/test_xla.py
@@ -302,6 +302,19 @@ def test_warning_if_tpus_not_used(tpu_available):
         Trainer(accelerator="cpu")
 
 
+@RunIf(tpu=True)
+def test_tpu_device_name():
+    from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1
+
+    if _XLA_GREATER_EQUAL_2_1:
+        from torch_xla._internal import tpu
+    else:
+        from torch_xla.experimental import tpu
+    import torch_xla.core.xla_env_vars as xenv
+
+    assert XLAAccelerator.device_name() == tpu.get_tpu_env()[xenv.ACCELERATOR_TYPE]
+
+
 @pytest.mark.parametrize(
     ("devices", "expected_device_ids"),
     [
diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py
index 878298c6bfd94..f7f1d890f5103 100644
--- a/tests/tests_pytorch/conftest.py
+++ b/tests/tests_pytorch/conftest.py
@@ -182,6 +182,7 @@ def thread_police_duuu_daaa_duuu_daaa():
 def mock_cuda_count(monkeypatch, n: int) -> None:
     monkeypatch.setattr(lightning.fabric.accelerators.cuda, "num_cuda_devices", lambda: n)
     monkeypatch.setattr(lightning.pytorch.accelerators.cuda, "num_cuda_devices", lambda: n)
+    monkeypatch.setattr(torch.cuda, "get_device_name", lambda _: "Mocked CUDA Device")
 
 
 @pytest.fixture
@@ -244,6 +245,11 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N
     monkeypatch.setitem(sys.modules, "torch_xla", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.core.xla_model", Mock())
     monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock())
+    monkeypatch.setattr(
+        lightning.pytorch.accelerators.xla.XLAAccelerator,
+        "device_name",
+        lambda *_: "Mocked TPU Device",
+    )
 
 
 @pytest.fixture
diff --git a/tests/tests_pytorch/plugins/test_cluster_integration.py b/tests/tests_pytorch/plugins/test_cluster_integration.py
index 08bd1707b5cfd..f74b771199aaa 100644
--- a/tests/tests_pytorch/plugins/test_cluster_integration.py
+++ b/tests/tests_pytorch/plugins/test_cluster_integration.py
@@ -66,7 +66,7 @@ def test_ranks_available_manual_strategy_selection(_, strategy_cls):
     """Test that the rank information is readily available after Trainer initialization."""
     num_nodes = 2
     for cluster, variables, expected in environment_combinations():
-        with mock.patch.dict(os.environ, variables):
+        with mock.patch.dict(os.environ, variables), mock.patch("torch.cuda.get_device_name", return_value="GPU"):
            strategy = strategy_cls(
                parallel_devices=[torch.device("cuda", 1), torch.device("cuda", 2)], cluster_environment=cluster
            )
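
For context, a minimal usage sketch (not part of the diff) of the `device_name()` classmethod this patch introduces; the imports and the accelerator classes come from the patch above, while the example device strings in the comments are placeholders that depend on local hardware.

```python
# Hypothetical usage sketch of the new Accelerator.device_name() classmethod.
# Per the patch: CUDAAccelerator reports torch.cuda.get_device_name(),
# MPSAccelerator reports "True (mps)", and an unavailable accelerator returns "".
from lightning.pytorch.accelerators import CUDAAccelerator, MPSAccelerator

print(CUDAAccelerator.device_name())  # e.g. "NVIDIA A100-SXM4-40GB", or "" if CUDA is unavailable
print(MPSAccelerator.device_name())   # "True (mps)" on Apple silicon, "" otherwise
```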