Skip to content

Commit cc1d2ca

Browse files
committed
dp fix?
1 parent 26bf03f commit cc1d2ca

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

tests/tests_pytorch/strategies/test_dp.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ def test_multi_gpu_model_dp(tmpdir):
9191

9292

9393
class ReductionTestModel(BoringModel):
94+
def __init__(self):
95+
super().__init__()
96+
self.train_outputs = []
97+
self.val_outputs = []
98+
self.tests_outputs = []
99+
94100
def train_dataloader(self):
95101
return DataLoader(RandomDataset(32, 64), batch_size=2)
96102

@@ -111,29 +117,32 @@ def add_outputs(self, output, device):
111117
def training_step(self, batch, batch_idx):
112118
output = super().training_step(batch, batch_idx)
113119
self.add_outputs(output, batch.device)
120+
self.train_outputs.append(output)
114121
return output
115122

116123
def validation_step(self, batch, batch_idx):
117124
output = super().validation_step(batch, batch_idx)
118125
self.add_outputs(output, batch.device)
126+
self.val_outputs.append(output)
119127
return output
120128

121129
def test_step(self, batch, batch_idx):
122130
output = super().test_step(batch, batch_idx)
123131
self.add_outputs(output, batch.device)
132+
self.tests_outputs.append(output)
124133
return output
125134

126-
def training_epoch_end(self, outputs):
127-
assert outputs[0]["loss"].shape == torch.Size([])
128-
self._assert_extra_outputs(outputs)
135+
def on_train_epoch_end(self):
136+
assert self.train_outputs[0]["loss"].shape == torch.Size([])
137+
self._assert_extra_outputs(self.train_outputs)
129138

130-
def validation_epoch_end(self, outputs):
131-
assert outputs[0]["x"].shape == torch.Size([2])
132-
self._assert_extra_outputs(outputs)
139+
def on_validation_epoch_end(self):
140+
assert self.val_outputs[0]["x"].shape == torch.Size([2])
141+
self._assert_extra_outputs(self.val_outputs)
133142

134-
def test_epoch_end(self, outputs):
135-
assert outputs[0]["y"].shape == torch.Size([2])
136-
self._assert_extra_outputs(outputs)
143+
def on_test_epoch_end(self):
144+
assert self.tests_outputs[0]["y"].shape == torch.Size([2])
145+
self._assert_extra_outputs(self.test_outputs)
137146

138147
def _assert_extra_outputs(self, outputs):
139148
out = outputs[0]["reduce_int"]

0 commit comments

Comments
 (0)