|
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 |
| -import inspect |
15 | 14 | from abc import ABC, abstractmethod
|
16 |
| -from typing import Any, Dict, Generic, Optional, Type, TypeVar, Union |
| 15 | +from typing import Any, Dict, Generic, Optional, TypeVar |
17 | 16 |
|
18 | 17 | from torchmetrics import Metric
|
19 | 18 |
|
20 | 19 | import pytorch_lightning as pl
|
21 | 20 | from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
|
22 | 21 | from pytorch_lightning.trainer.progress import BaseProgress
|
23 |
| -from pytorch_lightning.utilities.exceptions import MisconfigurationException |
24 | 22 | from pytorch_lightning.utilities.imports import _fault_tolerant_training
|
25 | 23 |
|
26 | 24 | T = TypeVar("T") # the output type of `run`
|
@@ -108,52 +106,6 @@ def connect(self, **kwargs: "Loop") -> None:
|
108 | 106 | Linked loops should form a tree.
|
109 | 107 | """
|
110 | 108 |
|
111 |
| - def replace(self, **loops: Union["Loop", Type["Loop"]]) -> None: |
112 |
| - """Optionally replace one or multiple of this loop's sub-loops. |
113 |
| -
|
114 |
| - This method takes care of instantiating the class (if necessary) with all existing arguments, connecting all |
115 |
| - sub-loops of the old loop to the new instance, setting the ``Trainer`` reference, and connecting the new loop to |
116 |
| - the parent. |
117 |
| -
|
118 |
| - Args: |
119 |
| - **loops: ``Loop`` subclasses or instances. The name used should match the loop attribute name you want to |
120 |
| - replace. |
121 |
| -
|
122 |
| - Raises: |
123 |
| - MisconfigurationException: When passing a ``Loop`` class, if the ``__init__`` arguments do not match those |
124 |
| - of the Loop class it replaces. |
125 |
| - """ |
126 |
| - new_loops = {} |
127 |
| - |
128 |
| - for name, type_or_object in loops.items(): |
129 |
| - old_loop = getattr(self, name) |
130 |
| - |
131 |
| - if isinstance(type_or_object, type): |
132 |
| - # compare the signatures |
133 |
| - old_parameters = inspect.signature(old_loop.__class__.__init__).parameters |
134 |
| - current_parameters = inspect.signature(type_or_object.__init__).parameters |
135 |
| - if old_parameters != current_parameters: |
136 |
| - raise MisconfigurationException( |
137 |
| - f"`{self.__class__.__name__}.replace({type_or_object.__name__})` can only be used if the" |
138 |
| - f" `__init__` signatures match but `{old_loop.__class__.__name__}` does not." |
139 |
| - ) |
140 |
| - # instantiate the loop |
141 |
| - kwargs = {p: getattr(old_loop, p) for p in old_parameters if p != "self"} |
142 |
| - loop = type_or_object(**kwargs) |
143 |
| - else: |
144 |
| - loop = type_or_object |
145 |
| - |
146 |
| - # connect sub-loops |
147 |
| - kwargs = {n: l for n, l in old_loop.__dict__.items() if isinstance(l, Loop)} |
148 |
| - if kwargs: |
149 |
| - loop.connect(**kwargs) |
150 |
| - # set the trainer reference |
151 |
| - loop.trainer = self.trainer |
152 |
| - |
153 |
| - new_loops[name] = loop |
154 |
| - # connect to self |
155 |
| - self.connect(**new_loops) |
156 |
| - |
157 | 109 | def on_skip(self) -> T:
|
158 | 110 | """The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.
|
159 | 111 |
|
|
0 commit comments