
Commit ae2138d

awaelchli authored and lantiga committed
Refresh _FabricOptimizer.__dict__ when loading a state dict (#18488)
(cherry picked from commit cf437ed)
1 parent a0cfb12 commit ae2138d

3 files changed: +46 −5 lines changed
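
For context: `_FabricOptimizer` copies the wrapped optimizer's `__dict__` at construction time so that attribute access such as `optimizer.state` or `optimizer.param_groups` resolves transparently on the wrapper. `torch.optim.Optimizer.load_state_dict`, however, rebinds those attributes on the wrapped optimizer, which left the wrapper's copies stale. Below is a minimal, self-contained sketch of the failure mode this commit fixes, using plain PyTorch only; the `SnapshotWrapper` class is an illustrative stand-in, not Lightning code:

```python
import torch


class SnapshotWrapper:
    """Illustrative stand-in for a wrapper that snapshots the wrapped optimizer's __dict__."""

    def __init__(self, optimizer):
        # Copy the optimizer's attributes (including `state` and `param_groups`) onto the wrapper
        self.__dict__ = dict(optimizer.__dict__)
        self._optimizer = optimizer


model = torch.nn.Linear(1, 1)

# Build up some optimizer state and capture it
optimizer = torch.optim.Adam(model.parameters())
model(torch.rand(1)).backward()
optimizer.step()
state_dict = optimizer.state_dict()

# Fresh optimizer, wrapped before the checkpoint is loaded
optimizer = torch.optim.Adam(model.parameters())
wrapper = SnapshotWrapper(optimizer)
optimizer.load_state_dict(state_dict)  # rebinds `optimizer.state` internally

print(bool(optimizer.state))  # True: the wrapped optimizer now has state
print(bool(wrapper.state))    # False: the wrapper still shows the stale, empty snapshot
```

The commit addresses this by giving the wrapper its own `load_state_dict` override, which delegates to the wrapped optimizer and then re-synchronizes the wrapper's `__dict__` through the new `_refresh()` helper, as the diffs below show.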

src/lightning/fabric/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -18,6 +18,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed FSDP full-precision `param_dtype` training (`16-mixed`, `bf16-mixed` and `32-true` configurations) to avoid FSDP assertion errors with PyTorch < 2.0 ([#18278](https://github.com/Lightning-AI/lightning/pull/18278))
 
 
+- Fixed an issue causing the `_FabricOptimizer.state` to remain outdated after loading with `load_state_dict` ([#18488](https://github.com/Lightning-AI/lightning/pull/18488))
+
+
 ## [2.0.7] - 2023-08-14
 
 ### Changed

src/lightning/fabric/wrappers.py

Lines changed: 22 additions & 3 deletions
@@ -48,13 +48,11 @@ def __init__(self, optimizer: Optimizer, strategy: Strategy, callbacks: Optional
             optimizer: The optimizer to wrap
             strategy: Reference to the strategy for handling the optimizer step
         """
-        # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
-        # not want to call on destruction of the `_FabricOptimizer
-        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("state_dict", "step", "__del__")}
         self.__class__ = type("Fabric" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})
         self._optimizer = optimizer
         self._strategy = strategy
         self._callbacks = callbacks or []
+        self._refresh()
 
     @property
     def optimizer(self) -> Optimizer:
@@ -63,6 +61,12 @@ def optimizer(self) -> Optimizer:
     def state_dict(self) -> Dict[str, Tensor]:
         return self._strategy.get_optimizer_state(self.optimizer)
 
+    def load_state_dict(self, state_dict: Dict[str, Tensor]) -> None:
+        self.optimizer.load_state_dict(state_dict)
+        # `Optimizer.load_state_dict` modifies `optimizer.__dict__`, so we need to update the `__dict__` on
+        # this wrapper
+        self._refresh()
+
     def step(self, closure: Optional[Callable] = None) -> Any:
         kwargs = {"closure": closure} if closure is not None else {}
         if hasattr(self._strategy, "model") and isinstance(self._strategy.model, Optimizable):
@@ -80,6 +84,21 @@ def step(self, closure: Optional[Callable] = None) -> Any:
             hook(strategy=self._strategy, optimizer=optimizer)
         return output
 
+    def _refresh(self) -> None:
+        """Refreshes the ``__dict__`` so that it matches the internal states in the wrapped optimizer.
+
+        This is only needed to present the user with an updated view in case they inspect the state of this wrapper.
+        """
+        # `__del__` is skipped in case the optimizer has implemented custom destructor logic which we would
+        # not want to call on destruction of the `_FabricOptimizer
+        self.__dict__.update(
+            {
+                k: v
+                for k, v in self.optimizer.__dict__.items()
+                if k not in ("load_state_dict", "state_dict", "step", "__del__")
+            }
+        )
+
 
 class _FabricModule(_DeviceDtypeModuleMixin):
     def __init__(
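
A note on the exclusion list in `_refresh()`: `load_state_dict`, `state_dict`, `step`, and `__del__` are deliberately not copied from the wrapped optimizer, because instance attributes shadow class-level methods in Python, so copying those bound methods onto the wrapper would silently bypass the wrapper's own overrides (and with them the strategy hooks). A small, generic illustration of that pitfall; the class names are invented for the example:

```python
class Engine:
    def step(self):
        return "plain step"


class WrappedEngine(Engine):
    def step(self):
        # The wrapper's override, e.g. routing the call through a strategy
        return "wrapped step"


engine = Engine()
wrapper = WrappedEngine()

# Copying the wrapped object's bound method into the instance __dict__ ...
wrapper.__dict__["step"] = engine.step

# ... shadows the class-level override: the wrapper's own `step` is never called.
print(wrapper.step())  # "plain step"
```

The same reasoning explains why this commit adds `load_state_dict` to the exclusion list at the same time it introduces the override.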

tests/tests_fabric/test_wrappers.py

Lines changed: 21 additions & 2 deletions
@@ -306,16 +306,35 @@ def test_fabric_optimizer_wraps():
 
 def test_fabric_optimizer_state_dict():
     """Test that the FabricOptimizer calls into the strategy to collect the state."""
-    optimizer = Mock()
+    optimizer = Mock(spec=torch.optim.Adam)
     strategy = Mock()
     fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
     fabric_optimizer.state_dict()
     strategy.get_optimizer_state.assert_called_with(optimizer)
 
 
+def test_fabric_optimizer_load_state_dict():
+    """Test that the FabricOptimizer can load the state dict on the wrapped optimizer and update its
+    internal `__dict__`."""
+    model = torch.nn.Linear(1, 1)
+    optimizer = torch.optim.Adam(model.parameters())
+    assert not optimizer.state  # a fresh optimizer has no state
+    model(torch.rand(1)).backward()
+    optimizer.step()
+    assert optimizer.state
+    state_dict = optimizer.state_dict()
+
+    optimizer = torch.optim.Adam(model.parameters())  # fresh optimizer
+    fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=Mock())
+    assert not fabric_optimizer.state  # a fresh optimizer has no state
+    fabric_optimizer.load_state_dict(state_dict)
+    assert fabric_optimizer.state
+    assert fabric_optimizer.optimizer.state_dict()["state"] == state_dict["state"]
+
+
 def test_fabric_optimizer_steps():
     """Test that the FabricOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
-    optimizer = Mock()
+    optimizer = Mock(spec=torch.optim.Adam)
     strategy = Mock(spec=["optimizer_step"])
     strategy.optimizer_step.return_value = 123
     fabric_optimizer = _FabricOptimizer(optimizer=optimizer, strategy=strategy)
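
At the user level, the effect of the fix is that the optimizer returned by `Fabric.setup` reflects a loaded checkpoint when its state is inspected afterwards. A hedged end-to-end sketch, assuming a single-device CPU run; the model, values, and variable names are illustrative:

```python
import torch
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()

# First run: set up, take one step, and capture the optimizer state
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters())
model, optimizer = fabric.setup(model, optimizer)
loss = model(torch.rand(1, 1)).sum()
fabric.backward(loss)
optimizer.step()
state_dict = optimizer.state_dict()

# Second run: fresh model/optimizer, restore the captured state through the wrapper
new_model = torch.nn.Linear(1, 1)
new_optimizer = torch.optim.Adam(new_model.parameters())
new_model, new_optimizer = fabric.setup(new_model, new_optimizer)
new_optimizer.load_state_dict(state_dict)

# With this commit, the wrapper's view of the state is refreshed after loading
assert new_optimizer.state
```

To exercise the new unit test directly, something like `pytest tests/tests_fabric/test_wrappers.py -k load_state_dict` should select it.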
