2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -35,6 +35,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Removed the `Trainer(ipus=...)` argument
* Removed the `Trainer(num_processes=...)` argument

- Removed the deprecated `pytorch_lightning.utilities.AllGatherGrad` class ([#16360](https://github.com/Lightning-AI/lightning/pull/16360))

- Removed the deprecated `resume_from_checkpoint` Trainer argument ([#16167](https://github.com/Lightning-AI/lightning/pull/16167))

- Removed the deprecated automatic GPU selection ([#16184](https://github.com/Lightning-AI/lightning/pull/16184))
1 change: 0 additions & 1 deletion src/pytorch_lightning/utilities/__init__.py
@@ -17,7 +17,6 @@

from lightning_fabric.utilities import LightningEnum # noqa: F401
from lightning_fabric.utilities import move_data_to_device # noqa: F401
from pytorch_lightning.utilities.distributed import AllGatherGrad # noqa: F401
from pytorch_lightning.utilities.enums import GradClipAlgorithmType # noqa: F401
from pytorch_lightning.utilities.grads import grad_norm # noqa: F401
from pytorch_lightning.utilities.imports import ( # noqa: F401
37 changes: 1 addition & 36 deletions src/pytorch_lightning/utilities/distributed.py
@@ -12,10 +12,9 @@
# limitations under the License.
"""Utilities that can be used with distributed training."""

from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional

import torch
from torch import Tensor
from torch.nn.parallel.distributed import DistributedDataParallel

from lightning_fabric.utilities.distributed import _all_gather_ddp_if_available as new_all_gather_ddp_if_available
@@ -177,40 +176,6 @@ def gather_all_tensors(*args: Any, **kwargs: Any) -> Any:
return new_gather_all_tensors(*args, **kwargs)


class AllGatherGrad(torch.autograd.Function):
"""Gathers tensors from the whole group and stacks them.

This implementation is copied from PyTorch.

.. deprecated:: v1.8.0
This function has been deprecated in v1.8.0 in favor of :func:`torch.distributed.nn.functional.all_gather` and
will be removed in v2.0.0.
"""

@staticmethod
def forward( # type: ignore[override]
ctx: Any,
tensor: Tensor,
group: Optional["torch.distributed.ProcessGroup"] = None,
) -> Tensor:
rank_zero_deprecation(
"`AllGatherGrad` has been deprecated in v1.8.0 and will be removed in v2.0.0."
" Use `torch.distributed.nn.functional.all_gather` instead.",
stacklevel=6,
)
ctx.group = group
gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(gathered_tensor, tensor, group=group)
gathered_tensor = torch.stack(gathered_tensor, dim=0)
return gathered_tensor

@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor, None]:
grad_output = torch.cat(grad_output)
torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
return grad_output[torch.distributed.get_rank()], None


def get_default_process_group_backend_for_device(*args: Any, **kwargs: Any) -> Any:
rank_zero_deprecation(
"`pytorch_lightning.utilities.distributed.get_default_process_group_backend_for_device` has been deprecated"
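The removed `AllGatherGrad.apply(tensor)` returned the per-rank tensors stacked along dim 0 while keeping gradients flowing back to each rank's input. The deprecation message points to `torch.distributed.nn.functional.all_gather` as the replacement; a minimal migration sketch (the helper name is ours, and it assumes a process group is already initialized on the default group):

```python
import torch
from torch.distributed.nn.functional import all_gather  # differentiable all_gather


def all_gather_and_stack(tensor: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper mirroring the removed ``AllGatherGrad.apply`` on the
    # default process group: the functional ``all_gather`` returns one tensor
    # per rank and keeps the autograd graph intact, so only the stacking that
    # ``AllGatherGrad`` performed needs to be reproduced here.
    gathered = all_gather(tensor)
    return torch.stack(gathered, dim=0)


# Usage inside an initialized distributed job:
# local = torch.ones(1, requires_grad=True)
# stacked = all_gather_and_stack(local)  # shape: (world_size, 1)
```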
10 changes: 0 additions & 10 deletions tests/tests_pytorch/deprecated_api/test_remove_2-0.py
@@ -303,16 +303,6 @@ def test_tuning_trainer_property():
trainer.tuning = True


@RunIf(skip_windows=True)
def test_v1_8_0_deprecated_all_gather_grad():
tensor1 = torch.ones(1, requires_grad=True)
with mock.patch("torch.distributed.all_gather"), mock.patch("torch.distributed.get_world_size", return_value=1):
from pytorch_lightning.utilities import AllGatherGrad

with pytest.deprecated_call(match="`AllGatherGrad` has been deprecated in v1.8"):
AllGatherGrad.apply(tensor1)


def test_v1_8_1_deprecated_rank_zero_only():
from pytorch_lightning.utilities.distributed import rank_zero_only

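With the class and its deprecation test removed, a quick sanity check that the old import path no longer resolves could look like the following (a hedged sketch, not part of this PR):

```python
import pytest


def test_all_gather_grad_import_is_removed():
    # After this change, ``AllGatherGrad`` is no longer re-exported, so the
    # old import path should fail with an ImportError.
    with pytest.raises(ImportError):
        from pytorch_lightning.utilities import AllGatherGrad  # noqa: F401
```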