11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- from typing import Any , Optional , Union
14
+ from typing import Any , Union
15
15
16
16
import torch
17
- import torch .nn as nn
18
- from torch .nn import DataParallel
19
- from torch .nn .parallel import DistributedDataParallel
20
17
21
18
import pytorch_lightning as pl
22
19
from lightning_fabric .utilities .device_dtype_mixin import _DeviceDtypeModuleMixin
23
- from pytorch_lightning .utilities .rank_zero import rank_zero_deprecation
24
20
25
21
26
22
class _LightningPrecisionModuleWrapperBase (_DeviceDtypeModuleMixin , torch .nn .Module ):
@@ -55,9 +51,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
55
51
56
52
57
53
class _LightningModuleWrapperBase (_DeviceDtypeModuleMixin , torch .nn .Module ):
58
- def __init__ (
59
- self , forward_module : Optional [Union ["pl.LightningModule" , _LightningPrecisionModuleWrapperBase ]]
60
- ) -> None :
54
+ def __init__ (self , forward_module : Union ["pl.LightningModule" , _LightningPrecisionModuleWrapperBase ]) -> None :
61
55
"""Wraps the user's LightningModule and redirects the forward call to the appropriate method, either
62
56
``training_step``, ``validation_step``, ``test_step``, or ``predict_step``.
63
57
@@ -75,8 +69,6 @@ def __init__(
75
69
"`forward_module` must be a `LightningModule` instance or have an attribute `.module` pointing to one,"
76
70
f" got: { forward_module .__class__ .__qualname__ } "
77
71
)
78
- # TODO: In v2.0.0, remove the Optional type from forward_module and remove the assertion
79
- assert forward_module is not None
80
72
self ._forward_module = forward_module
81
73
82
74
# set the parameters_to_ignore from LightningModule.
@@ -111,47 +103,3 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
111
103
if trainer .predicting :
112
104
return self ._forward_module .predict_step (* inputs , ** kwargs )
113
105
return self ._forward_module (* inputs , ** kwargs )
114
-
115
- @classmethod
116
- def _validate_init_arguments (
117
- cls ,
118
- pl_module : Optional [Union ["pl.LightningModule" , _LightningPrecisionModuleWrapperBase ]] = None ,
119
- forward_module : Optional [Union ["pl.LightningModule" , _LightningPrecisionModuleWrapperBase ]] = None ,
120
- ) -> None :
121
- # TODO: In v2.0.0, remove this method and mark the forward_module init argument in all subclasses as required
122
- if pl_module is not None :
123
- rank_zero_deprecation (
124
- f"The argument `pl_module` in `{ cls .__name__ } ` is deprecated in v1.8.0 and will be removed in"
125
- " v2.0.0. Please use `forward_module` instead."
126
- )
127
- elif forward_module is None :
128
- raise ValueError ("Argument `forward_module` is required." )
129
-
130
-
131
- def unwrap_lightning_module (wrapped_model : nn .Module , _suppress_warning : bool = False ) -> "pl.LightningModule" :
132
- """Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module``
133
- attributes on the wrapper.
134
-
135
- .. deprecated:: v1.8.0
136
- The function ``unwrap_lightning_module`` is deprecated in v1.8.0 and will be removed in v2.0.0. Access the
137
- ``LightningModule`` directly through the strategy attribute ``Strategy.lightning_module``.
138
-
139
- Raises:
140
- TypeError: If the unwrapping leads to a module that is not a LightningModule and that cannot be unwrapped
141
- further.
142
- """
143
- if not _suppress_warning :
144
- rank_zero_deprecation (
145
- "The function `unwrap_lightning_module` is deprecated in v1.8.0 and will be removed in v2.0.0. Access the"
146
- " `LightningModule` directly through the strategy attribute `Strategy.lightning_module`."
147
- )
148
- model = wrapped_model
149
- if isinstance (model , (DistributedDataParallel , DataParallel )):
150
- model = unwrap_lightning_module (model .module )
151
- if isinstance (model , _LightningModuleWrapperBase ):
152
- model = model .lightning_module
153
- if isinstance (model , _LightningPrecisionModuleWrapperBase ):
154
- model = model .module
155
- if not isinstance (model , pl .LightningModule ):
156
- raise TypeError (f"Unwrapping the module did not yield a `LightningModule`, got { type (model )} instead." )
157
- return model
0 commit comments