# limitations under the License.
from dataclasses import dataclass, field
from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
@@ -48,11 +48,13 @@ class ClosureResult(OutputResult):
        closure_loss: The loss with a graph attached.
        loss: A detached copy of the closure loss.
        extra: Any keys other than the loss returned.
+       was_dict: Whether the training step output was a dictionary.
    """

    closure_loss: Optional[Tensor]
    loss: Optional[Tensor] = field(init=False, default=None)
    extra: Dict[str, Any] = field(default_factory=dict)
+   was_dict: bool = False

    def __post_init__(self) -> None:
        self._clone_loss()
@@ -68,6 +70,7 @@ def from_training_step_output(
    ) -> "ClosureResult":
        closure_loss, extra = None, {}

+       was_dict = False
        if isinstance(training_step_output, dict):
            # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
            closure_loss = training_step_output.get("loss")
@@ -76,6 +79,7 @@ def from_training_step_output(
                    "In automatic_optimization, when `training_step` returns a dict, the 'loss' key needs to be present"
                )
            extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
+           was_dict = True
        elif isinstance(training_step_output, Tensor):
            closure_loss = training_step_output
        elif training_step_output is not None:
@@ -89,10 +93,12 @@ def from_training_step_output(
            # note: avoid in-place operation `x /= y` here on purpose
            closure_loss = closure_loss / normalize

-       return cls(closure_loss, extra=extra)
+       return cls(closure_loss, extra=extra, was_dict=was_dict)

-   def asdict(self) -> Dict[str, Any]:
-       return {"loss": self.loss, **self.extra}
+   def get(self) -> Union[Optional[Tensor], Dict[str, Any]]:
+       if self.was_dict:
+           return {"loss": self.loss, **self.extra}
+       return self.loss


class Closure(AbstractClosure[ClosureResult]):
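To illustrate the behavioral change, here is a minimal sketch (not part of the diff) of what the new `get()` accessor returns for the two supported `training_step` output shapes. The import path is an assumption based on the class names in this diff:

```python
import torch

# Assumed import path; the diff itself does not name the module.
from pytorch_lightning.loops.optimization.optimizer_loop import ClosureResult

# training_step returned a bare tensor: get() now yields the detached loss
# tensor directly instead of wrapping it in a {"loss": ...} dict.
tensor_result = ClosureResult.from_training_step_output(torch.tensor(1.0))
assert isinstance(tensor_result.get(), torch.Tensor)

# training_step returned a dict: get() preserves the dict shape, keeping any
# extra keys besides "loss" and "hiddens".
dict_result = ClosureResult.from_training_step_output({"loss": torch.tensor(1.0), "acc": 0.9})
out = dict_result.get()
assert set(out) == {"loss", "acc"}
```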
@@ -158,7 +164,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
        return self._result.loss


-_OUTPUTS_TYPE = Dict[int, Dict[str, Any]]
+_OUTPUTS_TYPE = Dict[int, Union[Optional[Tensor], Dict[str, Any]]]


class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
@@ -218,7 +224,7 @@ def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None:  # type: ignore
        if result.loss is not None:
            # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch
            # would be skipped otherwise
-           self._outputs[self.optimizer_idx] = result.get()
+           self._outputs[self.optimizer_idx] = result.get()
        self.optim_progress.optimizer_position += 1

    def on_run_end(self) -> _OUTPUTS_TYPE:
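As a hedged sketch of the widened `_OUTPUTS_TYPE`, the per-optimizer outputs collected by `OptimizerLoop` can now mix both shapes. The values below are illustrative only:

```python
from typing import Any, Dict, Optional, Union

import torch
from torch import Tensor

# Mirrors the type alias from the diff: each optimizer index now maps to either
# the bare loss tensor or the original dict returned by training_step.
_OUTPUTS_TYPE = Dict[int, Union[Optional[Tensor], Dict[str, Any]]]

# Illustrative contents: optimizer 0 returned a tensor from training_step,
# optimizer 1 returned a dict with an extra metric.
outputs: _OUTPUTS_TYPE = {
    0: torch.tensor(0.42),
    1: {"loss": torch.tensor(0.17), "acc": 0.91},
}
```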