1 change: 1 addition & 0 deletions docs/source-pytorch/conf.py
@@ -487,6 +487,7 @@ def _load_py_module(name: str, location: str) -> ModuleType:
("py:meth", "setup"),
("py:meth", "test_step"),
("py:meth", "toggle_optimizer"),
("py:meth", "toggled_optimizer"),
("py:class", "torch.ScriptModule"),
("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload"),
("py:class", "torch.distributed.fsdp.fully_sharded_data_parallel.MixedPrecision"),
2 changes: 1 addition & 1 deletion docs/source-pytorch/model/manual_optimization.rst
@@ -17,7 +17,7 @@ To manually optimize, do the following:
 * ``optimizer.zero_grad()`` to clear the gradients from the previous training step
 * ``self.manual_backward(loss)`` instead of ``loss.backward()``
 * ``optimizer.step()`` to update your model parameters
-* ``self.toggle_optimizer()`` and ``self.untoggle_optimizer()`` if needed
+* ``self.toggle_optimizer()`` and ``self.untoggle_optimizer()``, or ``self.toggled_optimizer()`` if needed
 
 Here is a minimal example of manual optimization.
 
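The minimal example the docs refer to is collapsed in this diff view. For orientation, here is a sketch that follows the four steps listed in the bullets above; it is an illustration consistent with the docs, not the exact collapsed snippet, and the model and loss are placeholders:

```python
import torch
import lightning.pytorch as pl


class ManualOptModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        # Opt out of automatic optimization
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()  # placeholder loss
        opt.zero_grad()                 # clear gradients from the previous step
        self.manual_backward(loss)      # instead of loss.backward()
        opt.step()                      # update the model parameters
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)
```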
4 changes: 4 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -11,6 +11,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Add enable_autolog_hparams argument to Trainer ([#20593](https://github.com/Lightning-AI/pytorch-lightning/pull/20593))
 
+
+- Add `toggled_optimizer(optimizer)` method to the LightningModule, which is a context manager version of `toggle_optimizer` and `untoggle_optimizer` ([#20771](https://github.com/Lightning-AI/pytorch-lightning/pull/20771))
+
+
 - For cross-device local checkpoints, instruct users to install `fsspec>=2025.5.0` if unavailable ([#20780](https://github.com/Lightning-AI/pytorch-lightning/pull/20780))
 
 
26 changes: 26 additions & 0 deletions src/lightning/pytorch/core/module.py
@@ -1141,6 +1141,32 @@ def untoggle_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) ->
         # save memory
         self._param_requires_grad_state = {}
 
+    @contextmanager
+    def toggled_optimizer(self, optimizer: Union[Optimizer, LightningOptimizer]) -> Generator:
+        """Makes sure only the gradients of the current optimizer's parameters are calculated in the training step
+        to prevent dangling gradients in a multiple-optimizer setup. Combines :meth:`toggle_optimizer` and
+        :meth:`untoggle_optimizer` into a context manager.
+
+        Args:
+            optimizer: The optimizer to toggle.
+
+        Example::
+
+            def training_step(...):
+                opt = self.optimizers()
+                with self.toggled_optimizer(opt):
+                    loss = ...
+                    opt.zero_grad()
+                    self.manual_backward(loss)
+                    opt.step()
+
+        """
+        self.toggle_optimizer(optimizer)
+        try:
+            yield
+        finally:
+            self.untoggle_optimizer(optimizer)
+
     def clip_gradients(
         self,
         optimizer: Optimizer,
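The docstring example above covers a single optimizer. To show where toggling actually matters, here is a sketch of a two-optimizer (GAN-style) `training_step` using the new context manager; it assumes a LightningModule with `automatic_optimization = False`, and `generator_loss`/`discriminator_loss` are hypothetical helpers, not part of this PR:

```python
def training_step(self, batch, batch_idx):
    opt_g, opt_d = self.optimizers()

    # Inside each block, parameters owned only by the *other* optimizer have
    # requires_grad=False, so no dangling gradients accumulate on them.
    with self.toggled_optimizer(opt_g):
        g_loss = self.generator_loss(batch)  # hypothetical helper
        opt_g.zero_grad()
        self.manual_backward(g_loss)
        opt_g.step()

    with self.toggled_optimizer(opt_d):
        d_loss = self.discriminator_loss(batch)  # hypothetical helper
        opt_d.zero_grad()
        self.manual_backward(d_loss)
        opt_d.step()
```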
16 changes: 16 additions & 0 deletions tests/tests_pytorch/core/test_lightning_module.py
@@ -119,6 +119,22 @@ def test_1_optimizer_toggle_model():
     assert not model._param_requires_grad_state
 
 
+def test_optimizer_toggle_model_context_manager():
+    """Test that the toggled_optimizer context manager toggles and restores state with a single optimizer."""
+    model = BoringModel()
+    trainer = Mock()
+    model.trainer = trainer
+    params = model.parameters()
+    optimizer = torch.optim.SGD(params, lr=0.1)
+    trainer.optimizers = [optimizer]
+
+    assert not model._param_requires_grad_state
+    # toggle_optimizer was failing with a single optimizer
+    with model.toggled_optimizer(optimizer):
+        assert model._param_requires_grad_state
+    assert not model._param_requires_grad_state
+
+
 def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmp_path):
     class TestModel(BoringModel):
         def __init__(self):
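The added test only covers the single-optimizer case. A natural follow-up, sketched here under the same Mock-trainer pattern used above, would assert that the context manager freezes the other optimizer's parameters and restores them on exit; the two-layer model and the optimizer split are hypothetical, not part of this PR:

```python
def test_toggled_optimizer_context_manager_2_optimizers():
    """Sketch: toggled_optimizer freezes the other optimizer's params and restores them on exit."""

    class TwoOptModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.layer_a = torch.nn.Linear(32, 32)
            self.layer_b = torch.nn.Linear(32, 2)

    model = TwoOptModel()
    trainer = Mock()
    model.trainer = trainer
    opt_a = torch.optim.SGD(model.layer_a.parameters(), lr=0.1)
    opt_b = torch.optim.SGD(model.layer_b.parameters(), lr=0.1)
    trainer.optimizers = [opt_a, opt_b]

    with model.toggled_optimizer(opt_a):
        # params of the toggled optimizer keep their requires_grad state ...
        assert all(p.requires_grad for p in model.layer_a.parameters())
        # ... while the other optimizer's params are frozen
        assert not any(p.requires_grad for p in model.layer_b.parameters())

    # untoggling on exit restores the saved requires_grad state
    assert all(p.requires_grad for p in model.layer_b.parameters())
    assert not model._param_requires_grad_state
```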