Commit 116bd74

Support more hardware platforms and no longer hard-code CUDA when calling _get_default_process_group_backend_for_device

Signed-off-by: taozhiwei <[email protected]>

1 parent f067626 · commit 116bd74

File tree

2 files changed (+24, -1 lines)


src/lightning/fabric/utilities/distributed.py
Lines changed: 5 additions & 1 deletion

@@ -319,7 +319,11 @@ def _destroy_dist_connection() -> None:
 
 
 def _get_default_process_group_backend_for_device(device: torch.device) -> str:
-    return "nccl" if device.type == "cuda" else "gloo"
+    device_backend_map = torch.distributed.Backend.default_device_backend_map
+    if device.type in device_backend_map:
+        return device_backend_map[device.type]
+    else:
+        return "gloo"
 
 
 class _DatasetSamplerWrapper(Dataset):
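For context: torch.distributed.Backend.default_device_backend_map maps device types to their default process-group backends (for example cpu -> gloo and, on CUDA builds, cuda -> nccl), and third-party backends registered via Backend.register_backend(..., devices=...) are added to it. The sketch below is not part of the commit; resolve_backend is a hypothetical stand-in that mirrors the lookup the patched function performs.

import torch
import torch.distributed

def resolve_backend(device: torch.device) -> str:
    # Prefer the backend registered for this device type; fall back to "gloo",
    # mirroring the patched _get_default_process_group_backend_for_device.
    backend_map = torch.distributed.Backend.default_device_backend_map
    return backend_map.get(device.type, "gloo")

resolve_backend(torch.device("cpu"))  # "gloo"
# On a CUDA build of PyTorch: resolve_backend(torch.device("cuda:0")) == "nccl"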

tests/tests_fabric/utilities/test_distributed.py
Lines changed: 19 additions & 0 deletions

@@ -19,6 +19,7 @@
     _gather_all_tensors,
     _InfiniteBarrier,
     _init_dist_connection,
+    _get_default_process_group_backend_for_device,
     _is_dtensor,
     _set_num_threads_if_needed,
     _suggested_max_num_threads,
@@ -242,6 +243,24 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
     atexit_mock.register.assert_not_called()
 
 
+def test_get_default_process_group_backend_for_device():
+    # register a custom backend for the test
+    torch.utils.rename_privateuse1_backend("pcu")
+
+    def mock_backend(store, group_rank, group_size, timeout):
+        pass
+
+    torch.distributed.Backend.register_backend(
+        "pccl",
+        lambda store, group_rank, group_size, timeout: mock_backend(store, group_rank, group_size, timeout),
+        devices=["pcu"],
+    )
+
+    # check that the default backend is resolved correctly for each device type
+    devices = [torch.device("cpu"), torch.device("cuda:0"), torch.device("xpu:0"), torch.device("pcu:0")]
+    backends = ["gloo", "nccl", "xccl", "pccl"]
+    for device, backend in zip(devices, backends):
+        assert _get_default_process_group_backend_for_device(device) == backend
+
+
 @RunIf(min_torch="2.4")
 def test_is_dtensor(monkeypatch):
     from torch.distributed._tensor import DTensor
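As background for the test's setup (not part of the commit): privateuse1 is PyTorch's reserved device slot for out-of-tree hardware, torch.utils.rename_privateuse1_backend gives it a concrete name, and registering a process-group backend with devices=[...] is what populates default_device_backend_map. A minimal standalone sketch of that pattern, reusing the test's placeholder names pcu/pccl:

import torch
import torch.distributed

# Give the reserved out-of-tree device slot a concrete name.
torch.utils.rename_privateuse1_backend("pcu")

def _noop_backend(store, group_rank, group_size, timeout):
    # A real integration would construct and return a ProcessGroup here.
    pass

# devices=["pcu"] records the "pcu" -> "pccl" default, which is exactly the
# entry _get_default_process_group_backend_for_device now looks up.
torch.distributed.Backend.register_backend("pccl", _noop_backend, devices=["pcu"])

assert torch.distributed.Backend.default_device_backend_map["pcu"] == "pccl"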
