diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index c0f623dda4730..ea624510f1611 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -38,7 +38,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an attribute error when loading a checkpoint into a quantized model using the `_lazy_load()` function ([#20121](https://github.com/Lightning-AI/lightning/pull/20121))
 
--
+- Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))
 
diff --git a/src/lightning/fabric/utilities/optimizer.py b/src/lightning/fabric/utilities/optimizer.py
index e2605ceca4670..2c57ec9d1f64a 100644
--- a/src/lightning/fabric/utilities/optimizer.py
+++ b/src/lightning/fabric/utilities/optimizer.py
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections.abc import MutableMapping
 from typing import Iterable
 
-from lightning_utilities.core.apply_func import apply_to_collection
 from torch import Tensor
 from torch.optim import Optimizer
 
-from lightning.fabric.utilities.apply_func import move_data_to_device
+from lightning.fabric.utilities.apply_func import apply_to_collection, move_data_to_device
 from lightning.fabric.utilities.types import _DEVICE
 
 
@@ -31,4 +31,12 @@ def _optimizers_to_device(optimizers: Iterable[Optimizer], device: _DEVICE) -> N
 def _optimizer_to_device(optimizer: Optimizer, device: _DEVICE) -> None:
     """Moves the state of a single optimizer to the device."""
     for p, v in optimizer.state.items():
-        optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
+        if not isinstance(v, MutableMapping):
+            # Support for custom optimizers
+            optimizer.state[p] = apply_to_collection(v, Tensor, move_data_to_device, device, allow_frozen=True)
+            continue
+        for key, val in v.items():
+            # The 'step' parameter needs to remain unmoved (possibly on the CPU) since that is where the optimizer
+            # needs it. See https://github.com/pytorch/pytorch/issues/74424
+            if key != "step":
+                v[key] = move_data_to_device(val, device)
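For context on the new `step` special case, here is a minimal sketch of the behavior the patched `_optimizer_to_device` guarantees, assuming a CUDA machine and a recent PyTorch where `torch.optim.Adam` keeps `step` as a CPU scalar tensor (the model and training step below are illustrative, not part of the patch): every state tensor follows the destination device except `step`, which stays behind to avoid the per-step device sync discussed in pytorch/pytorch#74424.

```python
import torch

from lightning.fabric.utilities.optimizer import _optimizer_to_device

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters())

# One backward/step pass materializes the per-parameter state
# ("step", "exp_avg", "exp_avg_sq" for Adam).
model(torch.randn(2, 2)).sum().backward()
optimizer.step()

_optimizer_to_device(optimizer, torch.device("cuda"))
for state in optimizer.state.values():
    assert state["step"].device.type == "cpu"      # deliberately left unmoved
    assert state["exp_avg"].device.type == "cuda"  # moved with the rest
```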
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index eba67ebffcbbe..d1e5c08411792 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -49,6 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the device of the logged value if needed ([#19814](https://github.com/Lightning-AI/pytorch-lightning/issues/19814))
 
+- Fixed `_optimizer_to_device` logic for special 'step' key in optimizer state causing performance regression ([#20019](https://github.com/Lightning-AI/lightning/pull/20019))
+
 - Fixed parameter counts in `ModelSummary` when model has distributed parameters (DTensor) ([#20163](https://github.com/Lightning-AI/pytorch-lightning/pull/20163))
 
diff --git a/tests/tests_fabric/utilities/test_optimizer.py b/tests/tests_fabric/utilities/test_optimizer.py
index 3aa78d507c346..83c7ed44120b9 100644
--- a/tests/tests_fabric/utilities/test_optimizer.py
+++ b/tests/tests_fabric/utilities/test_optimizer.py
@@ -1,36 +1,86 @@
-import collections
 import dataclasses
 
+import pytest
 import torch
 from lightning.fabric.utilities.optimizer import _optimizer_to_device
 from torch import Tensor
 
+from tests_fabric.helpers.runif import RunIf
 
-def test_optimizer_to_device():
-    @dataclasses.dataclass(frozen=True)
+
+@pytest.mark.parametrize(
+    "optimizer_class",
+    [
+        torch.optim.Adam,
+        torch.optim.AdamW,
+        torch.optim.SGD,
+        torch.optim.RMSprop,
+        torch.optim.Adagrad,
+        torch.optim.Adadelta,
+        torch.optim.Adamax,
+    ],
+)
+@pytest.mark.parametrize(
+    "src_device",
+    [
+        torch.device("cpu"),
+        pytest.param(torch.device("cuda"), marks=RunIf(min_cuda_gpus=1)),
+    ],
+)
+@pytest.mark.parametrize(
+    "dst_device",
+    [
+        torch.device("cpu"),
+        pytest.param(torch.device("cuda"), marks=RunIf(min_cuda_gpus=1)),
+    ],
+)
+def test_optimizer_to_device(optimizer_class, src_device, dst_device):
+    # Optimizer with no state initialized
+    model = torch.nn.Linear(2, 2, device=src_device)
+    optimizer = optimizer_class(model.parameters(), lr=0.1)
+    _optimizer_to_device(optimizer, dst_device)
+    _assert_opt_parameters_on_device(optimizer, dst_device)
+
+    # Optimizer with state initialized
+    model = torch.nn.Linear(2, 2, device=src_device)
+    optimizer = optimizer_class(model.parameters(), lr=0.1)
+    model(torch.randn(2, 2, device=src_device)).sum().backward()
+    optimizer.step()
+    _optimizer_to_device(optimizer, dst_device)
+    _assert_opt_parameters_on_device(optimizer, dst_device)
+
+
+def _assert_opt_parameters_on_device(opt, device):
+    for _, v in opt.state.items():
+        for key, item in v.items():
+            if not isinstance(item, Tensor):
+                continue
+            if key == "step":
+                # The "step" tensor needs to remain on CPU
+                assert item.device.type == "cpu"
+            else:
+                assert item.device.type == device.type
+
+
+@RunIf(min_cuda_gpus=1)
+@pytest.mark.parametrize("frozen", [True, False])
+def test_optimizer_to_device_with_dataclass_in_state(frozen):
+    src_device = torch.device("cpu")
+    dst_device = torch.device("cuda")
+    model = torch.nn.Linear(32, 2, device=src_device)
+
+    @dataclasses.dataclass(frozen=frozen)
     class FooState:
-        bar: int
+        integer: int
+        tensor: Tensor
 
     class TestOptimizer(torch.optim.SGD):
         def __init__(self, *args, **kwargs):
             super().__init__(*args, **kwargs)
-            self.state["dummy"] = torch.tensor(0)
-            self.state["frozen"] = FooState(0)
-
-    layer = torch.nn.Linear(32, 2)
-    opt = TestOptimizer(layer.parameters(), lr=0.1)
-    _optimizer_to_device(opt, "cpu")
-    if torch.cuda.is_available():
-        _optimizer_to_device(opt, "cuda")
-        assert_opt_parameters_on_device(opt, "cuda")
-
-
-def assert_opt_parameters_on_device(opt, device: str):
-    for param in opt.state.values():
-        # Not sure there are any global tensors in the state dict
-        if isinstance(param, Tensor):
-            assert param.data.device.type == device
-        elif isinstance(param, collections.abc.Mapping):
-            for subparam in param.values():
-                if isinstance(subparam, Tensor):
-                    assert param.data.device.type == device
+            self.state[model.weight] = {"dummy": torch.tensor(0)}
+            self.state[model.bias] = FooState(0, torch.tensor(0))
+
+    optimizer = TestOptimizer(model.parameters(), lr=0.1)
+    _optimizer_to_device(optimizer, dst_device)
+    assert optimizer.state[model.weight]["dummy"].device.type == dst_device.type
+    assert optimizer.state[model.bias].tensor.device.type == ("cpu" if frozen else dst_device.type)
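One note on the final assertion: it encodes how `apply_to_collection(..., allow_frozen=True)` treats frozen dataclasses, whose fields cannot be reassigned and are therefore returned unchanged rather than raising. Below is a minimal sketch of that behavior, assuming a CUDA device is available (`FooState` mirrors the test; the rest is illustrative):

```python
import dataclasses

import torch

from lightning.fabric.utilities.apply_func import apply_to_collection


@dataclasses.dataclass(frozen=True)
class FooState:
    tensor: torch.Tensor


state = FooState(torch.tensor(0))
# With allow_frozen=True the helper tolerates the FrozenInstanceError and keeps
# the original field values, so the tensor stays on its source device.
moved = apply_to_collection(state, torch.Tensor, lambda t: t.to("cuda"), allow_frozen=True)
assert moved.tensor.device.type == "cpu"  # matches the `frozen=True` branch of the new test
```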