Commit 523f44f

Keep the output type
1 parent 3a8b3fc commit 523f44f

File tree

10 files changed (+48, -85 lines)

10 files changed

+48
-85
lines changed

pytorch_lightning/loops/optimization/closure.py

Lines changed: 2 additions & 2 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Dict, Generic, Optional, TypeVar
+from typing import Any, Generic, Optional, TypeVar

 from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -22,7 +22,7 @@

 @dataclass
 class OutputResult:
-    def asdict(self) -> Dict[str, Any]:
+    def get(self) -> Any:
         raise NotImplementedError
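
The base-class change above only swaps the abstract method: where `asdict()` always produced a dict, the new `get()` may return whatever type the user's `training_step` produced, which is the point of the commit ("Keep the output type"). Below is a standalone sketch of that pattern, not the Lightning source; the subclass name `TensorOrDictResult` is hypothetical and only mirrors the `was_dict` flag the real subclasses add further down.

from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Union

import torch
from torch import Tensor


@dataclass
class OutputResult:
    # new interface: may return a dict or a bare tensor, hence `Any`
    def get(self) -> Any:
        raise NotImplementedError


@dataclass
class TensorOrDictResult(OutputResult):
    # hypothetical subclass mirroring the `was_dict` flag of ManualResult/ClosureResult
    extra: Dict[str, Any] = field(default_factory=dict)
    was_dict: bool = False

    def get(self) -> Union[Optional[Tensor], Dict[str, Any]]:
        # dict in -> dict out, tensor in -> bare loss tensor out
        return self.extra if self.was_dict else self.extra.get("loss")


assert isinstance(TensorOrDictResult({"loss": torch.ones(1)}, was_dict=True).get(), dict)
assert isinstance(TensorOrDictResult({"loss": torch.ones(1)}, was_dict=False).get(), Tensor)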

pytorch_lightning/loops/optimization/manual_loop.py

Lines changed: 13 additions & 7 deletions

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass, field
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union

 from torch import Tensor

@@ -31,15 +31,19 @@ class ManualResult(OutputResult):

     Attributes:
         extra: Anything returned by the ``training_step``.
+        was_dict: Whether the training step output was a dictionary.
     """

     extra: Dict[str, Any] = field(default_factory=dict)
+    was_dict: bool = False

     @classmethod
     def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) -> "ManualResult":
-        extra = {}
+        extra, was_dict = {}, False
+
         if isinstance(training_step_output, dict):
             extra = {k: v for k, v in training_step_output.items() if k != "hiddens"}
+            was_dict = True
         elif isinstance(training_step_output, Tensor):
             extra = {"loss": training_step_output}
         elif training_step_output is not None:
@@ -52,13 +56,15 @@ def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT])
         # we detach manually as it's expected that it will have a `grad_fn`
         extra["loss"] = extra["loss"].detach()

-        return cls(extra=extra)
+        return cls(extra=extra, was_dict=was_dict)

-    def asdict(self) -> Dict[str, Any]:
-        return self.extra
+    def get(self) -> Union[Optional[Tensor], Dict[str, Any]]:
+        if self.was_dict:
+            return self.extra
+        return self.extra.get("loss")


-_OUTPUTS_TYPE = Dict[str, Any]
+_OUTPUTS_TYPE = Union[Optional[Tensor], Dict[str, Any]]


 class ManualOptimization(Loop[_OUTPUTS_TYPE]):
@@ -122,7 +128,7 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
             self.trainer._results.cpu()

         self._done = True
-        self._output = result.asdict()
+        self._output = result.get()

     def on_run_end(self) -> _OUTPUTS_TYPE:
         """Returns the result of this loop, i.e., the post-processed outputs from the training step."""

pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 12 additions & 6 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -48,11 +48,13 @@ class ClosureResult(OutputResult):
         closure_loss: The loss with a graph attached.
         loss: A detached copy of the closure loss.
         extra: Any keys other than the loss returned.
+        was_dict: Whether the training step output was a dictionary.
     """

     closure_loss: Optional[Tensor]
     loss: Optional[Tensor] = field(init=False, default=None)
     extra: Dict[str, Any] = field(default_factory=dict)
+    was_dict: bool = False

     def __post_init__(self) -> None:
         self._clone_loss()
@@ -68,6 +70,7 @@ def from_training_step_output(
     ) -> "ClosureResult":
         closure_loss, extra = None, {}

+        was_dict = False
         if isinstance(training_step_output, dict):
             # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
             closure_loss = training_step_output.get("loss")
@@ -76,6 +79,7 @@
                     "In automatic_optimization, when `training_step` returns a dict, the 'loss' key needs to be present"
                 )
             extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
+            was_dict = True
         elif isinstance(training_step_output, Tensor):
             closure_loss = training_step_output
         elif training_step_output is not None:
@@ -89,10 +93,12 @@
             # note: avoid in-place operation `x /= y` here on purpose
             closure_loss = closure_loss / normalize

-        return cls(closure_loss, extra=extra)
+        return cls(closure_loss, extra=extra, was_dict=was_dict)

-    def asdict(self) -> Dict[str, Any]:
-        return {"loss": self.loss, **self.extra}
+    def get(self) -> Union[Optional[Tensor], Dict[str, Any]]:
+        if self.was_dict:
+            return {"loss": self.loss, **self.extra}
+        return self.loss


 class Closure(AbstractClosure[ClosureResult]):
@@ -158,7 +164,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
         return self._result.loss


-_OUTPUTS_TYPE = Dict[int, Dict[str, Any]]
+_OUTPUTS_TYPE = Dict[int, Union[Optional[Tensor], Dict[str, Any]]]


 class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
@@ -218,7 +224,7 @@ def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None: # type: ignore[override]
        if result.loss is not None:
            # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch
            # would be skipped otherwise
-           self._outputs[self.optimizer_idx] = result.asdict()
+           self._outputs[self.optimizer_idx] = result.get()
        self.optim_progress.optimizer_position += 1

    def on_run_end(self) -> _OUTPUTS_TYPE:
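
The automatic-optimization counterpart behaves the same way. A hedged sketch, again assuming this checkout, constructing `ClosureResult` directly as the updated test in tests/loops/optimization/test_optimizer_loop.py further down does.

import torch
from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult

loss = torch.tensor(3.0, requires_grad=True)

# tensor return path: `get()` yields the detached `loss` tensor itself
assert isinstance(ClosureResult(loss).get(), torch.Tensor)

# dict return path: `get()` yields {"loss": ..., **extra}, keeping the user's type
assert ClosureResult(loss, was_dict=True).get().keys() == {"loss"}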

tests/loops/optimization/test_manual_loop.py

Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@
 def test_manual_result():
     training_step_output = {"loss": torch.tensor(25.0, requires_grad=True), "something": "jiraffe"}
     result = ManualResult.from_training_step_output(training_step_output)
-    asdict = result.asdict()
+    asdict = result.get()
     assert not asdict["loss"].requires_grad
     assert asdict["loss"] == 25
     assert result.extra == asdict

tests/loops/optimization/test_optimizer_loop.py

Lines changed: 2 additions & 2 deletions

@@ -28,13 +28,13 @@

 def test_closure_result_deepcopy():
     closure_loss = torch.tensor(123.45)
-    result = ClosureResult(closure_loss)
+    result = ClosureResult(closure_loss, was_dict=True)

     assert closure_loss.data_ptr() == result.closure_loss.data_ptr()
     # the `loss` is cloned so the storage is different
     assert closure_loss.data_ptr() != result.loss.data_ptr()

-    copy = result.asdict()
+    copy = result.get()
     assert result.loss == copy["loss"]
     assert copy.keys() == {"loss"}

tests/loops/test_evaluation_loop_flow.py

Lines changed: 4 additions & 4 deletions

@@ -68,8 +68,8 @@ def backward(self, loss, optimizer, optimizer_idx):

    assert len(train_step_out) == 1
    train_step_out = train_step_out[0][0]
-   assert isinstance(train_step_out["loss"], torch.Tensor)
-   assert train_step_out["loss"].item() == 171
+   assert isinstance(train_step_out, torch.Tensor)
+   assert train_step_out.item() == 171

    # make sure the optimizer closure returns the correct things
    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
@@ -131,8 +131,8 @@ def backward(self, loss, optimizer, optimizer_idx):

    assert len(train_step_out) == 1
    train_step_out = train_step_out[0][0]
-   assert isinstance(train_step_out["loss"], torch.Tensor)
-   assert train_step_out["loss"].item() == 171
+   assert isinstance(train_step_out, torch.Tensor)
+   assert train_step_out.item() == 171

    # make sure the optimizer closure returns the correct things
    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
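
These assertion changes follow from the new `_OUTPUTS_TYPE`: with a single optimizer and a `training_step` returning a bare tensor, the optimizer loop now stores the tensor itself under the optimizer index instead of `{"loss": tensor}`. A hypothetical illustration of that shape (plain torch, not Lightning code):

import torch

# roughly what `OptimizerLoop._outputs` holds for optimizer 0 after this change
optimizer_outputs = {0: torch.tensor(171.0)}
out = optimizer_outputs[0]
assert isinstance(out, torch.Tensor)
assert out.item() == 171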

tests/loops/test_flow_warnings.py

Lines changed: 0 additions & 45 deletions

This file was deleted.

tests/loops/test_training_loop.py

Lines changed: 4 additions & 0 deletions

@@ -136,6 +136,10 @@ def training_step_end(self, outputs):
            loss = self.loss(outputs["batch"], outputs["output"])
            return loss

+       def training_epoch_end(self, outputs) -> None:
+           # since `training_step_end` returns a tensor, these are tensors
+           torch.stack(outputs).mean()
+
    # No error is raised
    model = ValidTrainStepEndModel()
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
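
The added `training_epoch_end` works because, with this change, a tensor-returning `training_step_end` produces tensor epoch outputs that can be stacked directly. A standalone sketch of the before/after handling (plain torch, no Lightning imports):

import torch

outputs = [torch.tensor(1.0), torch.tensor(2.0)]  # new behavior: plain tensors
assert torch.stack(outputs).mean().item() == 1.5

# previously each element was a dict, so the loss had to be unwrapped first:
# torch.stack([o["loss"] for o in outputs]).mean()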

tests/loops/test_training_loop_flow_scalar.py

Lines changed: 9 additions & 17 deletions

@@ -109,15 +109,9 @@ def training_step(self, batch, batch_idx):

        def training_epoch_end(self, outputs):
            self.training_epoch_end_called = True
-
            # verify we saw the current num of batches
            assert len(outputs) == 2
-
-           for b in outputs:
-               # time = 1
-               assert len(b) == 1
-               assert "loss" in b
-               assert isinstance(b, dict)
+           assert all(isinstance(o, torch.Tensor) for o in outputs)

        def backward(self, loss, optimizer, optimizer_idx):
            return LightningModule.backward(self, loss, optimizer, optimizer_idx)
@@ -151,8 +145,8 @@ def backward(self, loss, optimizer, optimizer_idx):

    assert len(train_step_out) == 1
    train_step_out = train_step_out[0][0]
-   assert isinstance(train_step_out["loss"], torch.Tensor)
-   assert train_step_out["loss"].item() == 171
+   assert isinstance(train_step_out, torch.Tensor)
+   assert train_step_out.item() == 171

    # make sure the optimizer closure returns the correct things
    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
@@ -184,12 +178,7 @@ def training_epoch_end(self, outputs):

            # verify we saw the current num of batches
            assert len(outputs) == 2
-
-           for b in outputs:
-               # time = 1
-               assert len(b) == 1
-               assert "loss" in b
-               assert isinstance(b, dict)
+           assert all(isinstance(o, torch.Tensor) for o in outputs)

        def backward(self, loss, optimizer, optimizer_idx):
            return LightningModule.backward(self, loss, optimizer, optimizer_idx)
@@ -223,8 +212,8 @@ def backward(self, loss, optimizer, optimizer_idx):

    assert len(train_step_out) == 1
    train_step_out = train_step_out[0][0]
-   assert isinstance(train_step_out["loss"], torch.Tensor)
-   assert train_step_out["loss"].item() == 171
+   assert isinstance(train_step_out, torch.Tensor)
+   assert train_step_out.item() == 171

    # make sure the optimizer closure returns the correct things
    opt_closure = trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop._make_closure(
@@ -285,6 +274,9 @@ def training_step(self, batch, batch_idx):
            self.log("a", loss, on_step=True, on_epoch=True)
            return loss if batch_idx % 2 else None

+       def training_epoch_end(self, outputs) -> None:
+           torch.stack(outputs).mean()
+
    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,

tests/plugins/test_ddp_spawn_plugin.py

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ def get_from_queue(self, queue: torch.multiprocessing.SimpleQueue) -> None:
        return super().get_from_queue(queue)


-@RunIf(skip_windows=True, skip_49370=True)
+@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
 def test_ddp_cpu():
     """Tests if device is set correctly when training for DDPSpawnPlugin."""
     trainer = Trainer(num_processes=2, fast_dev_run=True)
