|
19 | 19 | _gather_all_tensors,
|
20 | 20 | _InfiniteBarrier,
|
21 | 21 | _init_dist_connection,
|
| 22 | + _get_default_process_group_backend_for_device, |
22 | 23 | _is_dtensor,
|
23 | 24 | _set_num_threads_if_needed,
|
24 | 25 | _suggested_max_num_threads,
|
@@ -242,6 +243,24 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
|
242 | 243 | atexit_mock.register.assert_not_called()
|
243 | 244 |
|
244 | 245 |
|
def test_get_default_process_group_backend_for_device():
    """Each accelerator device type should map to its conventional default collective backend."""
    # Register a custom "pcu" device backed by a "pccl" backend so the
    # private-use-1 path is exercised alongside the built-in device types.
    torch.utils.rename_privateuse1_backend("pcu")

    def mock_backend(store, group_rank, group_size, timeout):
        # Stub process-group constructor; never invoked, only registered.
        pass

    # Pass the function directly — wrapping it in an equivalent lambda adds nothing.
    torch.distributed.Backend.register_backend("pccl", mock_backend, devices=["pcu"])

    # test that the default backend is correctly set for each device
    devices = [torch.device("cpu"), torch.device("cuda:0"), torch.device("xpu:0"), torch.device("pcu:0")]
    backends = ["gloo", "nccl", "xccl", "pccl"]
    for device, backend in zip(devices, backends):
        assert _get_default_process_group_backend_for_device(device) == backend

| 263 | + |
245 | 264 | @RunIf(min_torch="2.4")
|
246 | 265 | def test_is_dtensor(monkeypatch):
|
247 | 266 | from torch.distributed._tensor import DTensor
|
|
0 commit comments