
Commit 4c90e66

Merge branch 'lite/debug' into lite/debug-distributed
2 parents ce2ffe0 + 8a29ab7 commit 4c90e66

File tree

14 files changed: +26 additions, -184 deletions

src/pytorch_lightning/CHANGELOG.md

Lines changed: 7 additions & 0 deletions
@@ -51,6 +51,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Removed the deprecated code in `pl.utilities.distributed` ([#16390](https://github.com/Lightning-AI/lightning/pull/16390))

+- Mark the `forward_module` argument as required ([#16386](https://github.com/Lightning-AI/lightning/pull/16386))
+  * Removed the deprecated `pl_module` argument from the distributed module wrappers
+  * Removed the deprecated `pytorch_lightning.overrides.base.unwrap_lightning_module` function
+  * Removed the `pytorch_lightning.overrides.distributed.LightningDistributedModule` class
+  * Removed the deprecated `pytorch_lightning.overrides.fairscale.unwrap_lightning_module_sharded` function
+  * Removed the `pytorch_lightning.overrides.fairscale.LightningShardedDataParallel` class
+
 - Removed the deprecated automatic GPU selection ([#16184](https://github.com/Lightning-AI/lightning/pull/16184))
   * Removed the `Trainer(auto_select_gpus=...)` argument
   * Removed the `pytorch_lightning.tuner.auto_gpu_select.{pick_single_gpu,pick_multiple_gpus}` functions
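Taken together, these entries mean callers now construct the distributed wrappers with `forward_module` only. A minimal before/after sketch, assuming a pytorch_lightning build with this commit applied; `DemoModule` is a hypothetical stand-in for a user `LightningModule`, not part of this diff:

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase


class DemoModule(pl.LightningModule):  # hypothetical module for illustration
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


# Before (deprecated in v1.8.0, removed by this commit):
#     from pytorch_lightning.overrides import LightningDistributedModule
#     wrapped = LightningDistributedModule(pl_module=DemoModule())
# After: the base wrapper is used directly and `forward_module` is required.
wrapped = _LightningModuleWrapperBase(forward_module=DemoModule())
```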
Lines changed: 0 additions & 2 deletions
@@ -1,2 +0,0 @@
-from pytorch_lightning.overrides.data_parallel import LightningParallelModule  # noqa: F401
-from pytorch_lightning.overrides.distributed import LightningDistributedModule  # noqa: F401

src/pytorch_lightning/overrides/base.py

Lines changed: 2 additions & 54 deletions
@@ -11,16 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Union
+from typing import Any, Union

 import torch
-import torch.nn as nn
-from torch.nn import DataParallel
-from torch.nn.parallel import DistributedDataParallel

 import pytorch_lightning as pl
 from lightning_fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation


 class _LightningPrecisionModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
@@ -55,9 +51,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:


 class _LightningModuleWrapperBase(_DeviceDtypeModuleMixin, torch.nn.Module):
-    def __init__(
-        self, forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]]
-    ) -> None:
+    def __init__(self, forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
         """Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
         ``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.

@@ -75,8 +69,6 @@ def __init__(
                 "`forward_module` must be a `LightningModule` instance or have an attribute `.module` pointing to one,"
                 f" got: {forward_module.__class__.__qualname__}"
             )
-        # TODO: In v2.0.0, remove the Optional type from forward_module and remove the assertion
-        assert forward_module is not None
         self._forward_module = forward_module

         # set the parameters_to_ignore from LightningModule.
@@ -111,47 +103,3 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         if trainer.predicting:
             return self._forward_module.predict_step(*inputs, **kwargs)
         return self._forward_module(*inputs, **kwargs)
-
-    @classmethod
-    def _validate_init_arguments(
-        cls,
-        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-    ) -> None:
-        # TODO: In v2.0.0, remove this method and mark the forward_module init argument in all subclasses as required
-        if pl_module is not None:
-            rank_zero_deprecation(
-                f"The argument `pl_module` in `{cls.__name__}` is deprecated in v1.8.0 and will be removed in"
-                " v2.0.0. Please use `forward_module` instead."
-            )
-        elif forward_module is None:
-            raise ValueError("Argument `forward_module` is required.")
-
-
-def unwrap_lightning_module(wrapped_model: nn.Module, _suppress_warning: bool = False) -> "pl.LightningModule":
-    """Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module``
-    attributes on the wrapper.
-
-    .. deprecated:: v1.8.0
-        The function ``unwrap_lightning_module`` is deprecated in v1.8.0 and will be removed in v2.0.0. Access the
-        ``LightningModule`` directly through the strategy attribute ``Strategy.lightning_module``.
-
-    Raises:
-        TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped
-            further.
-    """
-    if not _suppress_warning:
-        rank_zero_deprecation(
-            "The function `unwrap_lightning_module` is deprecated in v1.8.0 and will be removed in v2.0.0. Access the"
-            " `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
-        )
-    model = wrapped_model
-    if isinstance(model, (DistributedDataParallel, DataParallel)):
-        model = unwrap_lightning_module(model.module)
-    if isinstance(model, _LightningModuleWrapperBase):
-        model = model.lightning_module
-    if isinstance(model, _LightningPrecisionModuleWrapperBase):
-        model = model.module
-    if not isinstance(model, pl.LightningModule):
-        raise TypeError(f"Unwrapping the module did not yield a `LightningModule`, got {type(model)} instead.")
-    return model
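The deprecation message in the removed `unwrap_lightning_module` already named its replacement: `Strategy.lightning_module`. A short sketch of that access path, assuming a `Trainer` is in scope (the trainer configuration below is illustrative, not taken from this commit):

```python
import pytorch_lightning as pl

# Before: lightning_module = unwrap_lightning_module(wrapped_or_plain_model)
# After: read it off the strategy, as the removed deprecation message suggested.
trainer = pl.Trainer(accelerator="cpu", devices=1)  # illustrative configuration
lightning_module = trainer.strategy.lightning_module  # None until a module is attached via fit/validate/etc.
```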

src/pytorch_lightning/overrides/data_parallel.py

Lines changed: 3 additions & 11 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 import numbers
 import warnings
-from typing import Any, Optional, Union
+from typing import Any, Union

 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
@@ -52,23 +52,15 @@ class LightningParallelModule(_LightningModuleWrapperBase):
         )

     Args:
-        pl_module: The module to wrap. See description for `forward_module`.
-
-            .. deprecated:: v1.8.0
-                The argument ``pl_module`` is deprecated in v1.8.0 and will be removed in v2.0.0. Please use
-                ``forward_module`` instead.
-
         forward_module: The module to wrap. If it's not a ``LightningModule``, it must have an attribute ``.module``
             pointing to a ``LightningModule`` reference.
     """

     def __init__(
         self,
-        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
+        forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase],
     ) -> None:
-        self._validate_init_arguments(pl_module, forward_module)
-        super().__init__(forward_module=(pl_module or forward_module))
+        super().__init__(forward_module=forward_module)
         _ignore_scalar_return_in_dp()

     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
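With `pl_module` removed, `LightningParallelModule` takes `forward_module` only. A usage sketch in the spirit of the class docstring's `DataParallel` example; `TinyModule` and the `device_ids` value are assumptions for illustration:

```python
import torch
import pytorch_lightning as pl
from pytorch_lightning.overrides.data_parallel import LightningParallelModule


class TinyModule(pl.LightningModule):  # hypothetical module for illustration
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(8, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


# `forward_module` is the only accepted argument now; passing `pl_module=...` raises a TypeError.
wrapped = LightningParallelModule(forward_module=TinyModule())
dp_model = torch.nn.DataParallel(module=wrapped, device_ids=[0])  # assumes one visible GPU
```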

src/pytorch_lightning/overrides/distributed.py

Lines changed: 1 addition & 13 deletions
@@ -12,26 +12,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import itertools
-from typing import Any, cast, Iterable, Iterator, List, Optional, Sized, Union
+from typing import Any, cast, Iterable, Iterator, List, Sized, Union

 import torch
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data import BatchSampler, DistributedSampler, Sampler

-import pytorch_lightning as pl
 from lightning_fabric.utilities.distributed import _DatasetSamplerWrapper
-from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
-
-
-class LightningDistributedModule(_LightningModuleWrapperBase):
-    def __init__(
-        self,
-        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-    ) -> None:
-        self._validate_init_arguments(pl_module, forward_module)
-        super().__init__(forward_module=(pl_module or forward_module))


 def _find_tensors(

src/pytorch_lightning/overrides/fairscale.py

Lines changed: 1 addition & 33 deletions
@@ -11,21 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Optional, Union
+from typing import List

-import torch.nn as nn
 from lightning_utilities.core.imports import package_available
 from torch.optim import Optimizer

-import pytorch_lightning as pl
 from lightning_fabric.plugins import Precision
 from lightning_fabric.utilities.imports import _IS_WINDOWS
-from pytorch_lightning.overrides.base import (
-    _LightningModuleWrapperBase,
-    _LightningPrecisionModuleWrapperBase,
-    unwrap_lightning_module,
-)
-from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation

 _FAIRSCALE_AVAILABLE = not _IS_WINDOWS and package_available("fairscale")

@@ -35,30 +27,6 @@
     OSS = object


-class LightningShardedDataParallel(_LightningModuleWrapperBase):
-    def __init__(
-        self,
-        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-    ) -> None:
-        self._validate_init_arguments(pl_module, forward_module)
-        super().__init__(forward_module=(pl_module or forward_module))
-
-
-def unwrap_lightning_module_sharded(wrapped_model: nn.Module) -> "pl.LightningModule":
-    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
-
-    rank_zero_deprecation(
-        "The function `unwrap_lightning_module_sharded` is deprecated in v1.8.0 and will be removed in v2.0.0."
-        " Access the `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
-    )
-    model = wrapped_model
-    if isinstance(model, ShardedDataParallel):
-        model = model.module
-
-    return unwrap_lightning_module(model, _suppress_warning=True)
-
-
 def _reinit_optimizers_with_oss(optimizers: List[Optimizer], precision: Precision, num_nodes: int) -> List["OSS"]:
     for x, optimizer in enumerate(optimizers):
         if not isinstance(optimizer, OSS):
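For code that still holds a fairscale `ShardedDataParallel` wrapper directly, the logic of the removed helpers can be inlined. A sketch under the assumptions that fairscale is installed and that `unwrap_sharded` (a name chosen here, not from this commit) receives such a wrapper:

```python
import torch
import pytorch_lightning as pl
from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
from pytorch_lightning.overrides.base import (
    _LightningModuleWrapperBase,
    _LightningPrecisionModuleWrapperBase,
)


def unwrap_sharded(wrapped_model: torch.nn.Module) -> pl.LightningModule:
    """Inline replacement for the removed `unwrap_lightning_module_sharded` (sketch only)."""
    model = wrapped_model
    if isinstance(model, ShardedDataParallel):
        model = model.module  # peel off the fairscale wrapper
    if isinstance(model, _LightningModuleWrapperBase):
        model = model.lightning_module  # peel off the Lightning forward redirection
    if isinstance(model, _LightningPrecisionModuleWrapperBase):
        model = model.module  # peel off the precision wrapper
    if not isinstance(model, pl.LightningModule):
        raise TypeError(f"Expected a `LightningModule`, got {type(model)} instead.")
    return model
```

In most user code, though, `trainer.strategy.lightning_module` remains the simpler route.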

src/pytorch_lightning/strategies/bagua.py

Lines changed: 1 addition & 4 deletions
@@ -64,11 +64,8 @@
 class LightningBaguaModule(_LightningModuleWrapperBase):
     def __init__(
         self,
-        forward_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
-        pl_module: Optional[Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]] = None,
+        forward_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase],
     ) -> None:
-        self._validate_init_arguments(pl_module, forward_module)
-        forward_module = pl_module or forward_module
         super().__init__(forward_module=forward_module)
         # Bagua use `bagua_module_name` to distinguish different modules
         self._bagua_module_name = f"{forward_module.__class__.__name__}{id(forward_module)}"

src/pytorch_lightning/strategies/ddp.py

Lines changed: 2 additions & 3 deletions
@@ -43,8 +43,7 @@
 from lightning_fabric.utilities.seed import reset_seed
 from lightning_fabric.utilities.types import ReduceOp
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.overrides import LightningDistributedModule
-from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.overrides.fairscale import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -291,7 +290,7 @@ def configure_ddp(self) -> None:
         log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
         self.pre_configure_ddp()
         assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
-        self.model = self._setup_model(LightningDistributedModule(self.model))
+        self.model = self._setup_model(_LightningModuleWrapperBase(self.model))
         self._register_ddp_hooks()

     def determine_ddp_device_ids(self) -> Optional[List[int]]:
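Conceptually, the strategy change above only swaps the wrapper class handed to `DistributedDataParallel`. A rough, runnable sketch of the resulting wrapping with a single-process gloo group; the process-group setup and `DemoModule` are illustrative, and the strategy's `_setup_model` normally adds device ids and other DDP kwargs that are skipped here:

```python
import os

import torch
import torch.distributed as dist
import pytorch_lightning as pl
from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase


class DemoModule(pl.LightningModule):  # hypothetical module for illustration
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layer(x)


# Single-process gloo group so DistributedDataParallel can be built locally.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# Roughly what `configure_ddp` does after this commit: wrap first, then hand to DDP.
ddp_model = DistributedDataParallel(_LightningModuleWrapperBase(DemoModule()))

dist.destroy_process_group()
```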

src/pytorch_lightning/strategies/ddp_spawn.py

Lines changed: 2 additions & 3 deletions
@@ -36,8 +36,7 @@
 from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_11
 from lightning_fabric.utilities.optimizer import _optimizers_to_device
 from lightning_fabric.utilities.types import ReduceOp
-from pytorch_lightning.overrides import LightningDistributedModule
-from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
@@ -211,7 +210,7 @@ def _register_ddp_hooks(self) -> None:
     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
         assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
-        self.model = self._setup_model(LightningDistributedModule(self.model))
+        self.model = self._setup_model(_LightningModuleWrapperBase(self.model))
         self._register_ddp_hooks()

         # set up optimizers after the wrapped module has been moved to the device

src/pytorch_lightning/strategies/hpu_parallel.py

Lines changed: 2 additions & 2 deletions
@@ -22,7 +22,7 @@
 import pytorch_lightning as pl
 from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment
 from lightning_fabric.utilities.distributed import group as _group
-from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
 from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
 from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
@@ -123,7 +123,7 @@ def configure_ddp(self) -> None:
         if _TORCH_LESSER_EQUAL_1_10_2:
             log.detail(f"{self.__class__.__name__}: configuring DistributedDataParallel")
             self._pre_configure_ddp()
-            self.model = self._setup_model(LightningDistributedModule(self.model))  # type: ignore
+            self.model = self._setup_model(_LightningModuleWrapperBase(self.model))  # type: ignore
             if self.root_device.type == "hpu" and self._static_graph:
                 self._model._set_static_graph()  # type: ignore
             self._register_ddp_hooks()
