diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index ec4eb261f2d3e..500f3a3e2aa92 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -319,7 +319,11 @@ def _destroy_dist_connection() -> None:
 
 
 def _get_default_process_group_backend_for_device(device: torch.device) -> str:
-    return "nccl" if device.type == "cuda" else "gloo"
+    """Return the corresponding distributed backend for a given device."""
+    device_backend_map = torch.distributed.Backend.default_device_backend_map
+    if device.type in device_backend_map:
+        return device_backend_map[device.type]
+    return "gloo"
 
 
 class _DatasetSamplerWrapper(Dataset):
diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py
index d65eaa810ff4d..51c4b320d5525 100644
--- a/tests/tests_fabric/utilities/test_distributed.py
+++ b/tests/tests_fabric/utilities/test_distributed.py
@@ -17,6 +17,7 @@
 from lightning.fabric.utilities.distributed import (
     _destroy_dist_connection,
     _gather_all_tensors,
+    _get_default_process_group_backend_for_device,
     _InfiniteBarrier,
     _init_dist_connection,
     _is_dtensor,
@@ -243,6 +244,27 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
     atexit_mock.register.assert_not_called()
 
 
+def test_get_default_process_group_backend_for_device():
+    """Test that each device type maps to its correct default process group backend."""
+    # register a custom backend for the test
+    torch.utils.rename_privateuse1_backend("pcu")
+
+    def mock_backend(store, group_rank, group_size, timeout):
+        pass
+
+    torch.distributed.Backend.register_backend(
+        "pccl",
+        lambda store, group_rank, group_size, timeout: mock_backend(store, group_rank, group_size, timeout),
+        devices=["pcu"],
+    )
+
+    # test that the default backend is correctly set for each device
+    devices = [torch.device("cpu"), torch.device("cuda:0"), torch.device("pcu:0")]
+    backends = ["gloo", "nccl", "pccl"]
+    for device, backend in zip(devices, backends):
+        assert _get_default_process_group_backend_for_device(device) == backend
+
+
 @RunIf(min_torch="2.4")
 def test_is_dtensor(monkeypatch):
     from torch.distributed._tensor import DTensor
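
For reference, the behavior of the new helper can be checked with a small standalone script. This is a minimal sketch, assuming a PyTorch build where torch.distributed is available and exposes Backend.default_device_backend_map (the attribute the patch relies on); the function body mirrors the one added above rather than importing Lightning.

    import torch


    def _get_default_process_group_backend_for_device(device: torch.device) -> str:
        # Mirrors the patched helper: prefer PyTorch's own device -> backend map,
        # fall back to "gloo" for device types it does not know about.
        device_backend_map = torch.distributed.Backend.default_device_backend_map
        if device.type in device_backend_map:
            return device_backend_map[device.type]
        return "gloo"


    # Stock PyTorch ships at least {"cpu": "gloo", "cuda": "nccl"} in the map,
    # so these resolve without registering any custom backend.
    print(_get_default_process_group_backend_for_device(torch.device("cpu")))     # gloo
    print(_get_default_process_group_backend_for_device(torch.device("cuda:0")))  # nccl

    # Device types missing from the map (e.g. "meta") fall back to "gloo".
    print(_get_default_process_group_backend_for_device(torch.device("meta")))    # gloo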