@@ -12,10 +12,9 @@
 # limitations under the License.
 """Utilities that can be used with distributed training."""

-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional

 import torch
-from torch import Tensor
 from torch.nn.parallel.distributed import DistributedDataParallel

 from lightning_fabric.utilities.distributed import _all_gather_ddp_if_available as new_all_gather_ddp_if_available
@@ -177,40 +176,6 @@ def gather_all_tensors(*args: Any, **kwargs: Any) -> Any:
     return new_gather_all_tensors(*args, **kwargs)


-class AllGatherGrad(torch.autograd.Function):
-    """Gathers tensors from the whole group and stacks them.
-
-    This implementation is copied from PyTorch.
-
-    .. deprecated:: v1.8.0
-        This function has been deprecated in v1.8.0 in favor of :func:`torch.distributed.nn.functional.all_gather` and
-        will be removed in v2.0.0.
-    """
-
-    @staticmethod
-    def forward(  # type: ignore[override]
-        ctx: Any,
-        tensor: Tensor,
-        group: Optional["torch.distributed.ProcessGroup"] = None,
-    ) -> Tensor:
-        rank_zero_deprecation(
-            "`AllGatherGrad` has been deprecated in v1.8.0 and will be removed in v2.0.0."
-            " Use `torch.distributed.nn.functional.all_gather` instead.",
-            stacklevel=6,
-        )
-        ctx.group = group
-        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
-        gathered_tensor = torch.stack(gathered_tensor, dim=0)
-        return gathered_tensor
-
-    @staticmethod
-    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]:
-        grad_output = torch.cat(grad_output)
-        torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
-        return grad_output[torch.distributed.get_rank()], None
-
-
 def get_default_process_group_backend_for_device(*args: Any, **kwargs: Any) -> Any:
     rank_zero_deprecation(
         "`pytorch_lightning.utilities.distributed.get_default_process_group_backend_for_device` has been deprecated"
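For reference, the removed AllGatherGrad autograd function can be replaced with torch.distributed.nn.functional.all_gather, as its own deprecation message suggests. Below is a minimal sketch, assuming the default process group has already been initialized; the helper name all_gather_with_grad is illustrative only and is not part of the Lightning or PyTorch API.

    import torch
    from torch.distributed.nn.functional import all_gather


    def all_gather_with_grad(tensor: torch.Tensor, group=None) -> torch.Tensor:
        # torch.distributed.nn.functional.all_gather is differentiable, so gradients
        # flow back to the local `tensor` on each rank, as AllGatherGrad did.
        gathered = all_gather(tensor, group=group)  # one tensor per rank
        return torch.stack(gathered, dim=0)  # same stacked layout AllGatherGrad.apply produced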