
Commit 2f301a1

WIP
1 parent c3a9bf0 commit 2f301a1

File tree: 11 files changed (+20 additions, -274 deletions)

src/lightning_app/utilities/introspection.py

Lines changed: 0 additions & 3 deletions
@@ -79,16 +79,13 @@ class LightningModuleVisitor(LightningVisitor):
         "save_hyperparameters",
         "test_step",
         "test_step_end",
-        "test_epoch_end",
         "to_onnx",
         "to_torchscript",
         "training_step",
         "training_step_end",
-        "training_epoch_end",
         "unfreeze",
         "validation_step",
         "validation_step_end",
-        "validation_epoch_end",
     }

     hooks: Set[str] = {

src/pytorch_lightning/callbacks/callback.py

Lines changed: 3 additions & 4 deletions
@@ -95,10 +95,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """Called when the train epoch ends.

-        To access all batch outputs at the end of the epoch, either:
-
-        1. Implement `training_epoch_end` in the `LightningModule` and access outputs via the module OR
-        2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.
+        FIXME(carlos): write example
+        To access all batch outputs at the end of the epoch, you can cache data across steps on the attribute(s) of the
+        `LightningModule` and access them in this hook
         """

     def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
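Since this commit removes `training_epoch_end`, a callback that needs every batch output must now collect the outputs itself, as the rewritten docstring above suggests. A minimal sketch of that caching pattern, assuming `training_step` returns a dict with a "loss" key; the callback name and cache attribute are illustrative and not part of this commit:

    import torch
    import pytorch_lightning as pl


    class TrainLossSummary(pl.Callback):
        """Illustrative callback: cache per-batch outputs, reduce them at epoch end."""

        def __init__(self) -> None:
            self.batch_losses = []  # hypothetical cache, rebuilt every epoch

        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
            # `outputs` is whatever `training_step` returned for this batch
            self.batch_losses.append(outputs["loss"].detach())

        def on_train_epoch_end(self, trainer, pl_module):
            # all cached batch outputs are available here
            epoch_mean = torch.stack(self.batch_losses).mean()
            pl_module.log("train_loss_epoch_mean", epoch_mean)
            self.batch_losses.clear()  # free memory for the next epoch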

src/pytorch_lightning/core/hooks.py

Lines changed: 3 additions & 4 deletions
@@ -169,10 +169,9 @@ def on_train_epoch_start(self) -> None:
     def on_train_epoch_end(self) -> None:
         """Called in the training loop at the very end of the epoch.

-        To access all batch outputs at the end of the epoch, either:
-
-        1. Implement `training_epoch_end` in the LightningModule OR
-        2. Cache data across steps on the attribute(s) of the `LightningModule` and access them in this hook
+        FIXME(carlos): write example
+        To access all batch outputs at the end of the epoch, you can cache data across steps on the attribute(s) of the
+        `LightningModule` and access them in this hook
         """

     def on_validation_epoch_start(self) -> None:
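On the `LightningModule` side, the caching pattern this docstring (and presumably the pending FIXME example) points at looks roughly like the sketch below. The attribute, helper, and metric names are made up for illustration only:

    import torch
    from pytorch_lightning import LightningModule


    class LitModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []  # illustrative cache attribute

        def training_step(self, batch, batch_idx):
            loss = self.compute_loss(batch)  # hypothetical helper
            self.training_step_outputs.append(loss.detach())
            return loss

        def on_train_epoch_end(self):
            # every batch output of the epoch is available on the cached attribute
            epoch_mean = torch.stack(self.training_step_outputs).mean()
            self.log("train_epoch_mean", epoch_mean)
            self.training_step_outputs.clear()  # free memory for the next epoch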

src/pytorch_lightning/core/module.py

Lines changed: 1 addition & 153 deletions
@@ -50,13 +50,7 @@
 from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCHMETRICS_GREATER_EQUAL_0_9_1
 from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.types import (
-    _METRIC,
-    EPOCH_OUTPUT,
-    LRSchedulerPLType,
-    LRSchedulerTypeUnion,
-    STEP_OUTPUT,
-)
+from pytorch_lightning.utilities.types import _METRIC, LRSchedulerPLType, LRSchedulerTypeUnion, STEP_OUTPUT

 warning_cache = WarningCache()
 log = logging.getLogger(__name__)
@@ -767,51 +761,11 @@ def training_step_end(self, training_step_outputs):
         See the :ref:`Multi GPU Training <gpu_intermediate>` guide for more details.
         """

-    def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
-        """Called at the end of the training epoch with the outputs of all training steps. Use this in case you
-        need to do something with all the outputs returned by :meth:`training_step`.
-
-        .. code-block:: python
-
-            # the pseudocode for these calls
-            train_outs = []
-            for train_batch in train_data:
-                out = training_step(train_batch)
-                train_outs.append(out)
-            training_epoch_end(train_outs)
-
-        Args:
-            outputs: List of outputs you defined in :meth:`training_step`. If there are multiple optimizers, the lists
-                have the dimensions (n_batches, n_optimizers). Dimensions of length 1 are squeezed.
-
-        Return:
-            None
-
-        Note:
-            If this method is not overridden, this won't be called.
-
-        .. code-block:: python
-
-            def training_epoch_end(self, training_step_outputs):
-                # do something with all training_step outputs
-                for out in training_step_outputs:
-                    ...
-        """
-
     def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         r"""
         Operates on a single batch of data from the validation set.
         In this step you'd might generate examples or calculate anything of interest like accuracy.

-        .. code-block:: python
-
-            # the pseudocode for these calls
-            val_outs = []
-            for val_batch in val_data:
-                out = validation_step(val_batch)
-                val_outs.append(out)
-            validation_epoch_end(val_outs)
-
         Args:
             batch: The output of your :class:`~torch.utils.data.DataLoader`.
             batch_idx: The index of this batch.
@@ -825,13 +779,10 @@ def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         .. code-block:: python

             # pseudocode of order
-            val_outs = []
             for val_batch in val_data:
                 out = validation_step(val_batch)
                 if defined("validation_step_end"):
                     out = validation_step_end(out)
-                val_outs.append(out)
-            val_outs = validation_epoch_end(val_outs)


         .. code-block:: python
@@ -940,65 +891,12 @@ def validation_step_end(self, val_step_outputs):
         See the :ref:`Multi GPU Training <gpu_intermediate>` guide for more details.
         """

-    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
-        """Called at the end of the validation epoch with the outputs of all validation steps.
-
-        .. code-block:: python
-
-            # the pseudocode for these calls
-            val_outs = []
-            for val_batch in val_data:
-                out = validation_step(val_batch)
-                val_outs.append(out)
-            validation_epoch_end(val_outs)
-
-        Args:
-            outputs: List of outputs you defined in :meth:`validation_step`, or if there
-                are multiple dataloaders, a list containing a list of outputs for each dataloader.
-
-        Return:
-            None
-
-        Note:
-            If you didn't define a :meth:`validation_step`, this won't be called.
-
-        Examples:
-            With a single dataloader:
-
-            .. code-block:: python
-
-                def validation_epoch_end(self, val_step_outputs):
-                    for out in val_step_outputs:
-                        ...
-
-            With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
-            one entry per dataloader, while the inner list contains the individual outputs of
-            each validation step for that dataloader.
-
-            .. code-block:: python
-
-                def validation_epoch_end(self, outputs):
-                    for dataloader_output_result in outputs:
-                        dataloader_outs = dataloader_output_result.dataloader_i_outputs
-
-                    self.log("final_metric", final_value)
-        """
-
     def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         r"""
         Operates on a single batch of data from the test set.
         In this step you'd normally generate examples or calculate anything of interest
         such as accuracy.

-        .. code-block:: python
-
-            # the pseudocode for these calls
-            test_outs = []
-            for test_batch in test_data:
-                out = test_step(test_batch)
-                test_outs.append(out)
-            test_epoch_end(test_outs)
-
         Args:
             batch: The output of your :class:`~torch.utils.data.DataLoader`.
             batch_idx: The index of this batch.
@@ -1118,56 +1016,6 @@ def test_step_end(self, output_results):
         See the :ref:`Multi GPU Training <gpu_intermediate>` guide for more details.
         """

-    def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
-        """Called at the end of a test epoch with the output of all test steps.
-
-        .. code-block:: python
-
-            # the pseudocode for these calls
-            test_outs = []
-            for test_batch in test_data:
-                out = test_step(test_batch)
-                test_outs.append(out)
-            test_epoch_end(test_outs)
-
-        Args:
-            outputs: List of outputs you defined in :meth:`test_step_end`, or if there
-                are multiple dataloaders, a list containing a list of outputs for each dataloader
-
-        Return:
-            None
-
-        Note:
-            If you didn't define a :meth:`test_step`, this won't be called.
-
-        Examples:
-            With a single dataloader:
-
-            .. code-block:: python
-
-                def test_epoch_end(self, outputs):
-                    # do something with the outputs of all test batches
-                    all_test_preds = test_step_outputs.predictions
-
-                    some_result = calc_all_results(all_test_preds)
-                    self.log(some_result)
-
-            With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
-            one entry per dataloader, while the inner list contains the individual outputs of
-            each test step for that dataloader.
-
-            .. code-block:: python
-
-                def test_epoch_end(self, outputs):
-                    final_value = 0
-                    for dataloader_outputs in outputs:
-                        for test_step_out in dataloader_outputs:
-                            # do something
-                            final_value += test_step_out
-
-                    self.log("final_metric", final_value)
-        """
-
     def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
         """Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. By default, it
         calls :meth:`~pytorch_lightning.core.module.LightningModule.forward`. Override to add any processing logic.

src/pytorch_lightning/demos/boring_classes.py

Lines changed: 4 additions & 16 deletions
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import cast, Dict, Iterator, List, Optional, Tuple, Union
+from typing import Dict, Iterator, List, Optional, Tuple

 import torch
 import torch.nn as nn
@@ -23,7 +23,7 @@
 from lightning_fabric.utilities.types import _TORCH_LRSCHEDULER
 from pytorch_lightning import LightningDataModule, LightningModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
+from pytorch_lightning.utilities.types import STEP_OUTPUT


 class RandomDictDataset(Dataset):
@@ -89,14 +89,14 @@ class TestModel(BoringModel):
             def training_step(self, ...):
                 ... # do your own thing

-            training_epoch_end = None # disable hook
+            training_step_end = None # disable hook

         or

         Example::

             model = BoringModel()
-            model.training_epoch_end = None # disable hook
+            model.training_step_end = None # disable hook
         """
         super().__init__()
         self.layer = torch.nn.Linear(32, 2)
@@ -120,24 +120,12 @@ def training_step(self, batch: Tensor, batch_idx: int) -> STEP_OUTPUT:
     def training_step_end(self, training_step_outputs: STEP_OUTPUT) -> STEP_OUTPUT:
         return training_step_outputs

-    def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
-        outputs = cast(List[Dict[str, Tensor]], outputs)
-        torch.stack([x["loss"] for x in outputs]).mean()
-
     def validation_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]:
         return {"x": self.step(batch)}

-    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
-        outputs = cast(List[Dict[str, Tensor]], outputs)
-        torch.stack([x["x"] for x in outputs]).mean()
-
     def test_step(self, batch: Tensor, batch_idx: int) -> Optional[STEP_OUTPUT]:
         return {"y": self.step(batch)}

-    def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
-        outputs = cast(List[Dict[str, Tensor]], outputs)
-        torch.stack([x["y"] for x in outputs]).mean()
-
     def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List[_TORCH_LRSCHEDULER]]:
         optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
         lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
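The epoch-end reductions deleted from `BoringModel` above could be reproduced with the same caching pattern if a test ever needs them again. A hedged sketch for the validation case only, not part of the commit; the subclass and attribute names are illustrative:

    import torch

    from pytorch_lightning.demos.boring_classes import BoringModel


    class CachingBoringModel(BoringModel):
        # illustrative subclass: emulate the removed `validation_epoch_end` reduction
        def __init__(self):
            super().__init__()
            self.val_outputs = []

        def validation_step(self, batch, batch_idx):
            out = super().validation_step(batch, batch_idx)
            self.val_outputs.append(out["x"])
            return out

        def on_validation_epoch_end(self):
            # same mean reduction the removed hook computed
            torch.stack(self.val_outputs).mean()
            self.val_outputs.clear()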

src/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py

Lines changed: 4 additions & 10 deletions
@@ -30,7 +30,7 @@
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
 from pytorch_lightning.utilities.imports import _fault_tolerant_training
 from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
+from pytorch_lightning.utilities.types import STEP_OUTPUT


 class _EvaluationEpochLoop(_Loop):
@@ -44,7 +44,6 @@ def __init__(self) -> None:
         super().__init__()
         self.batch_progress = BatchProgress()

-        self._outputs: EPOCH_OUTPUT = []
         self._dl_max_batches: Union[int, float] = 0
         self._data_fetcher: Optional[AbstractDataFetcher] = None
         self._dataloader_state_dict: Dict[str, Any] = {}
@@ -55,9 +54,7 @@ def done(self) -> bool:
         """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
         return self.batch_progress.current.completed >= self._dl_max_batches

-    def run(
-        self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict
-    ) -> EPOCH_OUTPUT:
+    def run(self, data_fetcher: AbstractDataFetcher, dl_max_batches: Union[int, float], kwargs: OrderedDict) -> None:
         self.reset()
         self.on_run_start(data_fetcher, dl_max_batches, kwargs)
         while not self.done:
@@ -67,7 +64,7 @@ def run(
             except StopIteration:
                 break
             self._restarting = False
-        return self.on_run_end()
+        self.on_run_end()

     def reset(self) -> None:
         """Resets the loop's internal state."""
@@ -172,11 +169,8 @@ def advance(
         # if fault tolerant is enabled and process has been notified, exit.
         self.trainer._exit_gracefully_on_signal()

-    def on_run_end(self) -> EPOCH_OUTPUT:
-        """Returns the outputs of the whole run."""
-        outputs, self._outputs = self._outputs, []  # free memory
+    def on_run_end(self):
         self._data_fetcher = None
-        return outputs

     def teardown(self) -> None:
         # in case the model changes

0 commit comments
