 from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning.fabric.utilities.distributed import group as _group
 from lightning.pytorch.accelerators.hpu import _HPU_AVAILABLE
-from lightning.pytorch.overrides.torch_distributed import broadcast_object_list
 from lightning.pytorch.plugins.io.hpu_plugin import HPUCheckpointIO
 from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
 from lightning.pytorch.plugins.precision import PrecisionPlugin
@@ -106,7 +105,7 @@ def broadcast(self, obj: object, src: int = 0) -> object:  # type: ignore
         if self.global_rank != src:
             obj = [None]

-        broadcast_object_list(obj, src, group=_group.WORLD)
+        _hpu_broadcast_object_list(obj, src, group=_group.WORLD)
         return obj[0]

     def on_after_backward(self) -> None:
@@ -138,3 +137,80 @@ def teardown(self) -> None:
         # Was set to local rank
         os.environ.pop("ID", None)
         os.environ.pop("HCCL_DISTRIBUTED_BACKEND", None)
+
+
+# The code below is taken from PyTorch `torch/distributed/distributed_c10d.py`;
+# the distributed-backend and tensor-type updates for the Habana backend are applied here before broadcasting.
+def _hpu_broadcast_object_list(object_list, src=0, group=None, device=None):  # type: ignore
+    from torch.distributed import _rank_not_in_group, Backend, broadcast, get_backend, get_rank
+    from torch.distributed.distributed_c10d import _object_to_tensor, _tensor_to_object
+
+    if _rank_not_in_group(group):
+        return
+
+    my_rank = get_rank()
+    # Serialize object_list elements to tensors on the src rank.
+    if my_rank == src:
+        tensor_list, size_list = zip(*[_object_to_tensor(obj, device) for obj in object_list])
+        object_sizes_tensor = torch.cat(size_list)
+    else:
+        object_sizes_tensor = torch.empty(len(object_list), dtype=torch.long)
+
+    # Current device selection.
+    # To preserve backwards compatibility, ``device`` defaults to ``None``,
+    # in which case we run the existing device-selection logic, i.e.
+    # ``current_device`` is CUDA if the backend is NCCL, otherwise the CPU device. If
+    # it is not ``None``, we move the size and object tensors to be
+    # broadcast to this device.
+    group_backend = get_backend(group)
+    is_nccl_backend = group_backend == Backend.NCCL
+    is_hpu_backend = os.environ.get("HCCL_DISTRIBUTED_BACKEND") == "1"
+    if device is not None:
+        if is_nccl_backend and device.type != "cuda":
+            raise ValueError("device type must be cuda for nccl backend")
+        current_device = device
+    else:
+        current_device = torch.device("cpu")
+        if is_nccl_backend:
+            # See note about using torch.cuda.current_device() here in
+            # docstring. We cannot simply use my_rank since rank == device is
+            # not necessarily true.
+            current_device = torch.device("cuda", torch.cuda.current_device())
+    if is_nccl_backend:
+        object_sizes_tensor = object_sizes_tensor.to(current_device)

+    elif is_hpu_backend:
+        current_device = torch.device("hpu")
+        # Workaround: HPU does not support long tensors for collectives
+        if object_sizes_tensor.type() in ("torch.LongTensor", "torch.hpu.LongTensor"):
+            object_sizes_tensor = object_sizes_tensor.int()
+        else:
+            print("unhandled hpu object_sizes_tensor type :: ", object_sizes_tensor.type())
+        object_sizes_tensor = object_sizes_tensor.to(current_device)
+
+    # Broadcast object sizes
+    broadcast(object_sizes_tensor, src=src, group=group)
+
+    # Concatenate and broadcast serialized object tensors
+    if my_rank == src:
+        object_tensor = torch.cat(tensor_list)
+    else:
+        object_tensor = torch.empty(
+            torch.sum(object_sizes_tensor).int().item(),
+            dtype=torch.uint8,
+        )
+
+    if is_nccl_backend or is_hpu_backend:
+        object_tensor = object_tensor.to(current_device)
+
+    broadcast(object_tensor, src=src, group=group)
+    # Deserialize objects using their stored sizes.
+    offset = 0
+    if my_rank != src:
+        for i, obj_size in enumerate(object_sizes_tensor):
+            obj_view = object_tensor[offset : offset + obj_size]
+            obj_view = obj_view.type(torch.uint8)
+            if obj_view.device != torch.device("cpu"):
+                obj_view = obj_view.cpu()
+            offset += obj_size
+            object_list[i] = _tensor_to_object(obj_view, obj_size)
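
For orientation (not part of the diff): the `broadcast` override above follows the usual object-collective pattern, where every rank passes a one-element list and the helper fills it in place on non-source ranks. A minimal, hypothetical usage sketch of that pattern, assuming `torch.distributed` is already initialized with the HCCL backend and that `_hpu_broadcast_object_list` is importable from this module:

import torch.distributed as dist

def broadcast_from_rank_zero(payload, src=0):
    # Each rank supplies a single-element list; non-src ranks start with a placeholder
    # that _hpu_broadcast_object_list overwrites in place.
    obj = [payload] if dist.get_rank() == src else [None]
    _hpu_broadcast_object_list(obj, src, group=dist.group.WORLD)
    return obj[0]

# Example call (hypothetical values): every rank receives rank 0's dict.
# metrics = broadcast_from_rank_zero({"best_score": 0.91})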
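The Habana-specific part of the helper is the size-tensor handling: `_object_to_tensor` yields a `torch.long` element count per object, and HCCL collectives reject long tensors, so the sizes are cast to int32 before the first `broadcast`. Below is a single-process, CPU-only sketch of the serialize/deserialize round trip, using the same private c10d helpers and the two-argument `_object_to_tensor` signature the patch relies on (both are private APIs whose signatures vary across PyTorch releases):

import torch
from torch.distributed.distributed_c10d import _object_to_tensor, _tensor_to_object

# Arbitrary example payload; no process group is needed for the round trip itself.
payload_tensor, size_tensor = _object_to_tensor({"epoch": 3}, torch.device("cpu"))
# size_tensor is a one-element torch.long tensor; the patch downcasts it for the HCCL collective.
sizes_for_hccl = size_tensor.int()
# Deserialization needs only the byte payload and its length, mirroring the loop at the end of the helper.
restored = _tensor_to_object(payload_tensor, size_tensor.item())
assert restored == {"epoch": 3}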