
Commit 850e498

Remove the deprecated AllGatherGrad class (#16360)
1 parent b3b9486 commit 850e498

4 files changed: +3 −47 lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -35,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Removed the `Trainer(ipus=...)` argument
 * Removed the `Trainer(num_processes=...)` argument
 
+- Removed the deprecated `pytorch_lightning.utilities.AllGatherGrad` class ([#16360](https://github.com/Lightning-AI/lightning/pull/16360))
+
 - Removed the deprecated `resume_from_checkpoint` Trainer argument ([#16167](https://github.com/Lightning-AI/lightning/pull/16167))
 
 - Removed the deprecated automatic GPU selection ([#16184](https://github.com/Lightning-AI/lightning/pull/16184))

src/pytorch_lightning/utilities/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@
 
 from lightning_fabric.utilities import LightningEnum  # noqa: F401
 from lightning_fabric.utilities import move_data_to_device  # noqa: F401
-from pytorch_lightning.utilities.distributed import AllGatherGrad  # noqa: F401
 from pytorch_lightning.utilities.enums import GradClipAlgorithmType  # noqa: F401
 from pytorch_lightning.utilities.grads import grad_norm  # noqa: F401
 from pytorch_lightning.utilities.imports import (  # noqa: F401

src/pytorch_lightning/utilities/distributed.py

Lines changed: 1 addition & 36 deletions
@@ -12,10 +12,9 @@
 # limitations under the License.
 """Utilities that can be used with distributed training."""
 
-from typing import Any, Callable, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional
 
 import torch
-from torch import Tensor
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 from lightning_fabric.utilities.distributed import _all_gather_ddp_if_available as new_all_gather_ddp_if_available
@@ -177,40 +176,6 @@ def gather_all_tensors(*args: Any, **kwargs: Any) -> Any:
     return new_gather_all_tensors(*args, **kwargs)
 
 
-class AllGatherGrad(torch.autograd.Function):
-    """Gathers tensors from the whole group and stacks them.
-
-    This implementation is copied from PyTorch.
-
-    .. deprecated:: v1.8.0
-        This function has been deprecated in v1.8.0 in favor of :func:`torch.distributed.nn.functional.all_gather` and
-        will be removed in v2.0.0.
-    """
-
-    @staticmethod
-    def forward(  # type: ignore[override]
-        ctx: Any,
-        tensor: Tensor,
-        group: Optional["torch.distributed.ProcessGroup"] = None,
-    ) -> Tensor:
-        rank_zero_deprecation(
-            "`AllGatherGrad` has been deprecated in v1.8.0 and will be removed in v2.0.0."
-            " Use `torch.distributed.nn.functional.all_gather` instead.",
-            stacklevel=6,
-        )
-        ctx.group = group
-        gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
-        torch.distributed.all_gather(gathered_tensor, tensor, group=group)
-        gathered_tensor = torch.stack(gathered_tensor, dim=0)
-        return gathered_tensor
-
-    @staticmethod
-    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]:
-        grad_output = torch.cat(grad_output)
-        torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
-        return grad_output[torch.distributed.get_rank()], None
-
-
 def get_default_process_group_backend_for_device(*args: Any, **kwargs: Any) -> Any:
     rank_zero_deprecation(
         "`pytorch_lightning.utilities.distributed.get_default_process_group_backend_for_device` has been deprecated"

tests/tests_pytorch/deprecated_api/test_remove_2-0.py

Lines changed: 0 additions & 10 deletions
@@ -303,16 +303,6 @@ def test_tuning_trainer_property():
     trainer.tuning = True
 
 
-@RunIf(skip_windows=True)
-def test_v1_8_0_deprecated_all_gather_grad():
-    tensor1 = torch.ones(1, requires_grad=True)
-    with mock.patch("torch.distributed.all_gather"), mock.patch("torch.distributed.get_world_size", return_value=1):
-        from pytorch_lightning.utilities import AllGatherGrad
-
-        with pytest.deprecated_call(match="`AllGatherGrad` has been deprecated in v1.8"):
-            AllGatherGrad.apply(tensor1)
-
-
 def test_v1_8_1_deprecated_rank_zero_only():
     from pytorch_lightning.utilities.distributed import rank_zero_only
 
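
The deprecation test is deleted outright because the class no longer exists. A downstream project that wants a guard for its own migration could use something like the sketch below; the test name and assertion are hypothetical and not part of this change.

```python
# Hypothetical downstream regression guard, not part of this commit: once the
# class is removed, the old import path should raise ImportError.
import pytest


def test_all_gather_grad_is_gone():
    with pytest.raises(ImportError):
        from pytorch_lightning.utilities import AllGatherGrad  # noqa: F401
```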
