@@ -91,6 +91,12 @@ def test_multi_gpu_model_dp(tmpdir):
91
91
92
92
93
93
class ReductionTestModel (BoringModel ):
94
+ def __init__ (self ):
95
+ super ().__init__ ()
96
+ self .train_outputs = []
97
+ self .val_outputs = []
98
+ self .tests_outputs = []
99
+
94
100
def train_dataloader (self ):
95
101
return DataLoader (RandomDataset (32 , 64 ), batch_size = 2 )
96
102
@@ -111,29 +117,32 @@ def add_outputs(self, output, device):
111
117
def training_step (self , batch , batch_idx ):
112
118
output = super ().training_step (batch , batch_idx )
113
119
self .add_outputs (output , batch .device )
120
+ self .train_outputs .append (output )
114
121
return output
115
122
116
123
def validation_step (self , batch , batch_idx ):
117
124
output = super ().validation_step (batch , batch_idx )
118
125
self .add_outputs (output , batch .device )
126
+ self .val_outputs .append (output )
119
127
return output
120
128
121
129
def test_step (self , batch , batch_idx ):
122
130
output = super ().test_step (batch , batch_idx )
123
131
self .add_outputs (output , batch .device )
132
+ self .tests_outputs .append (output )
124
133
return output
125
134
126
- def training_epoch_end (self , outputs ):
127
- assert outputs [0 ]["loss" ].shape == torch .Size ([])
128
- self ._assert_extra_outputs (outputs )
135
+ def on_train_epoch_end (self ):
136
+ assert self . train_outputs [0 ]["loss" ].shape == torch .Size ([])
137
+ self ._assert_extra_outputs (self . train_outputs )
129
138
130
- def validation_epoch_end (self , outputs ):
131
- assert outputs [0 ]["x" ].shape == torch .Size ([2 ])
132
- self ._assert_extra_outputs (outputs )
139
+ def on_validation_epoch_end (self ):
140
+ assert self . val_outputs [0 ]["x" ].shape == torch .Size ([2 ])
141
+ self ._assert_extra_outputs (self . val_outputs )
133
142
134
- def test_epoch_end (self , outputs ):
135
- assert outputs [0 ]["y" ].shape == torch .Size ([2 ])
136
- self ._assert_extra_outputs (outputs )
143
+ def on_test_epoch_end (self ):
144
+ assert self . tests_outputs [0 ]["y" ].shape == torch .Size ([2 ])
145
+ self ._assert_extra_outputs (self . test_outputs )
137
146
138
147
def _assert_extra_outputs (self , outputs ):
139
148
out = outputs [0 ]["reduce_int" ]
0 commit comments