from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_13, _TORCHMETRICS_GREATER_EQUAL_0_9_1
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_warn, WarningCache
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
- from pytorch_lightning.utilities.types import (
-     _METRIC,
-     EPOCH_OUTPUT,
-     LRSchedulerPLType,
-     LRSchedulerTypeUnion,
-     STEP_OUTPUT,
- )
+ from pytorch_lightning.utilities.types import _METRIC, LRSchedulerPLType, LRSchedulerTypeUnion, STEP_OUTPUT

warning_cache = WarningCache()
log = logging.getLogger(__name__)
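Note: `EPOCH_OUTPUT` drops out of this import because the `*_epoch_end` hooks that consumed it are deleted below. As rough orientation — the exact alias definitions are an assumption, not shown in this diff — the removed alias was essentially a list of per-step outputs:

import torch
from typing import Any, Dict, List, Union

# hypothetical reconstruction for orientation only; not the verbatim source
STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]]  # what a *_step may return
EPOCH_OUTPUT = List[STEP_OUTPUT]                   # one entry per step in the epoch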
@@ -767,51 +761,11 @@ def training_step_end(self, training_step_outputs):
        See the :ref:`Multi GPU Training <gpu_intermediate>` guide for more details.
        """

-     def training_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
-         """Called at the end of the training epoch with the outputs of all training steps. Use this in case you
-         need to do something with all the outputs returned by :meth:`training_step`.
-
-         .. code-block:: python
-
-             # the pseudocode for these calls
-             train_outs = []
-             for train_batch in train_data:
-                 out = training_step(train_batch)
-                 train_outs.append(out)
-             training_epoch_end(train_outs)
-
-         Args:
-             outputs: List of outputs you defined in :meth:`training_step`. If there are multiple optimizers, the lists
-                 have the dimensions (n_batches, n_optimizers). Dimensions of length 1 are squeezed.
-
-         Return:
-             None
-
-         Note:
-             If this method is not overridden, this won't be called.
-
-         .. code-block:: python
-
-             def training_epoch_end(self, training_step_outputs):
-                 # do something with all training_step outputs
-                 for out in training_step_outputs:
-                     ...
-         """
-
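With `training_epoch_end` removed, the aggregation it enabled has to be done manually. A minimal sketch of the replacement pattern, assuming the `on_train_epoch_end` hook (which receives no outputs argument) remains available; the attribute name `training_step_outputs` and the `compute_loss` helper are illustrative, not part of the API:

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []  # manual accumulator replacing `outputs`

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        self.training_step_outputs.append(loss.detach())
        return loss

    def on_train_epoch_end(self):
        # no `outputs` argument anymore; read the accumulator instead
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("train_epoch_mean", epoch_mean)
        self.training_step_outputs.clear()  # free memory for the next epoch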
    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        r"""
        Operates on a single batch of data from the validation set.
        In this step you might generate examples or calculate anything of interest like accuracy.

-         .. code-block:: python
-
-             # the pseudocode for these calls
-             val_outs = []
-             for val_batch in val_data:
-                 out = validation_step(val_batch)
-                 val_outs.append(out)
-             validation_epoch_end(val_outs)
-
        Args:
            batch: The output of your :class:`~torch.utils.data.DataLoader`.
            batch_idx: The index of this batch.
@@ -825,13 +779,10 @@ def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        .. code-block:: python

            # pseudocode of order
-             val_outs = []
            for val_batch in val_data:
                out = validation_step(val_batch)
                if defined("validation_step_end"):
                    out = validation_step_end(out)
-                 val_outs.append(out)
-             val_outs = validation_epoch_end(val_outs)


        .. code-block:: python
@@ -940,65 +891,12 @@ def validation_step_end(self, val_step_outputs):
        See the :ref:`Multi GPU Training <gpu_intermediate>` guide for more details.
        """

-     def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
-         """Called at the end of the validation epoch with the outputs of all validation steps.
-
-         .. code-block:: python
-
-             # the pseudocode for these calls
-             val_outs = []
-             for val_batch in val_data:
-                 out = validation_step(val_batch)
-                 val_outs.append(out)
-             validation_epoch_end(val_outs)
-
-         Args:
-             outputs: List of outputs you defined in :meth:`validation_step`, or if there
-                 are multiple dataloaders, a list containing a list of outputs for each dataloader.
-
-         Return:
-             None
-
-         Note:
-             If you didn't define a :meth:`validation_step`, this won't be called.
-
-         Examples:
-             With a single dataloader:
-
-             .. code-block:: python
-
-                 def validation_epoch_end(self, val_step_outputs):
-                     for out in val_step_outputs:
-                         ...
-
-             With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
-             one entry per dataloader, while the inner list contains the individual outputs of
-             each validation step for that dataloader.
-
-             .. code-block:: python
-
-                 def validation_epoch_end(self, outputs):
-                     for dataloader_output_result in outputs:
-                         dataloader_outs = dataloader_output_result.dataloader_i_outputs
-
-                     self.log("final_metric", final_value)
-         """
-
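The replacement presumably mirrors the training case: accumulate in `validation_step`, consume and reset in `on_validation_epoch_end`. A sketch under that assumption (`val_step_outputs` and `compute_accuracy` are illustrative names, not Lightning API):

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.val_step_outputs = []  # illustrative accumulator

    def validation_step(self, batch, batch_idx):
        acc = self.compute_accuracy(batch)  # hypothetical helper
        self.val_step_outputs.append(acc)
        return acc

    def on_validation_epoch_end(self):
        # aggregate what validation_step collected, then reset for the next epoch
        self.log("val_acc_epoch", torch.stack(self.val_step_outputs).mean())
        self.val_step_outputs.clear()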
    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        r"""
        Operates on a single batch of data from the test set.
        In this step you'd normally generate examples or calculate anything of interest
        such as accuracy.

-         .. code-block:: python
-
-             # the pseudocode for these calls
-             test_outs = []
-             for test_batch in test_data:
-                 out = test_step(test_batch)
-                 test_outs.append(out)
-             test_epoch_end(test_outs)
-
        Args:
            batch: The output of your :class:`~torch.utils.data.DataLoader`.
            batch_idx: The index of this batch.
@@ -1118,56 +1016,6 @@ def test_step_end(self, output_results):
        See the :ref:`Multi GPU Training <gpu_intermediate>` guide for more details.
        """

-     def test_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
-         """Called at the end of a test epoch with the output of all test steps.
-
-         .. code-block:: python
-
-             # the pseudocode for these calls
-             test_outs = []
-             for test_batch in test_data:
-                 out = test_step(test_batch)
-                 test_outs.append(out)
-             test_epoch_end(test_outs)
-
-         Args:
-             outputs: List of outputs you defined in :meth:`test_step_end`, or if there
-                 are multiple dataloaders, a list containing a list of outputs for each dataloader
-
-         Return:
-             None
-
-         Note:
-             If you didn't define a :meth:`test_step`, this won't be called.
-
-         Examples:
-             With a single dataloader:
-
-             .. code-block:: python
-
-                 def test_epoch_end(self, outputs):
-                     # do something with the outputs of all test batches
-                     all_test_preds = test_step_outputs.predictions
-
-                     some_result = calc_all_results(all_test_preds)
-                     self.log(some_result)
-
-             With multiple dataloaders, `outputs` will be a list of lists. The outer list contains
-             one entry per dataloader, while the inner list contains the individual outputs of
-             each test step for that dataloader.
-
-             .. code-block:: python
-
-                 def test_epoch_end(self, outputs):
-                     final_value = 0
-                     for dataloader_outputs in outputs:
-                         for test_step_out in dataloader_outputs:
-                             # do something
-                             final_value += test_step_out
-
-                     self.log("final_metric", final_value)
-         """
-
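The multiple-dataloader case from the removed docstring also has to be handled by hand now. A sketch under the same assumptions, keying an accumulator by `dataloader_idx` (all names here are illustrative):

from collections import defaultdict

import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.test_step_outputs = defaultdict(list)  # one list per dataloader index

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        pred = self.predict_batch(batch)  # hypothetical helper returning a tensor
        self.test_step_outputs[dataloader_idx].append(pred)
        return pred

    def on_test_epoch_end(self):
        # replicate the old per-dataloader nesting with an explicit dict
        for idx, outs in self.test_step_outputs.items():
            self.log(f"test_mean/dataloader_{idx}", torch.cat(outs).mean())
        self.test_step_outputs.clear()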
    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any:
        """Step function called during :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`. By default, it
        calls :meth:`~pytorch_lightning.core.module.LightningModule.forward`. Override to add any processing logic.