Commit 220e3b8

Add lazy checkpoint loading for FSDP full-state checkpoints (#18150)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6511ac2 commit 220e3b8

File tree: 4 files changed (+315, -5 lines)

src/lightning/fabric/strategies/fsdp.py

Lines changed: 9 additions & 4 deletions

@@ -67,6 +67,7 @@
     _TORCH_GREATER_EQUAL_2_1,
 )
 from lightning.fabric.utilities.init import _EmptyInit
+from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
 from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import _PATH
@@ -573,12 +574,17 @@ def load_checkpoint(
             return metadata

         if _is_full_checkpoint(path):
-            checkpoint = torch.load(path, map_location="cpu")
+            checkpoint = _lazy_load(path) if _TORCH_GREATER_EQUAL_2_0 else torch.load(path, map_location="cpu")
             _load_raw_module_state(checkpoint.pop(module_key), module=module, strict=strict)

             if isinstance(state, Module):
                 return {}

+            if _TORCH_GREATER_EQUAL_2_0:
+                # Materialize lazy tensors if there are any left in the checkpoint
+                # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues
+                checkpoint = _materialize_tensors(checkpoint)
+
             # Load optimizer states
             for optim_key, optim in optimizers.items():
                 # rank0_only should be false because we need to load the optimizer state on all ranks
@@ -821,9 +827,8 @@ def _load_raw_module_state(path_or_ckpt: Union[Path, Dict[str, Any]], module: Mo

     from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

-    # This is inefficient, as multiple copies of the checkpoint are held in CPU memory at once.
-    # There is currently no other way because `summon_full_params` does not support write-back from rank 0 only.
-    state_dict = torch.load(path_or_ckpt, map_location="cpu") if not isinstance(path_or_ckpt, dict) else path_or_ckpt
+    # Use `lazy_load` instead of `torch.load` here to avoid storing a copy of the full checkpoint per rank
+    state_dict = _lazy_load(path_or_ckpt) if isinstance(path_or_ckpt, Path) else path_or_ckpt
     with FSDP.summon_full_params(module, writeback=True, rank0_only=False):
         module.load_state_dict(state_dict, strict=strict)
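
For context, a minimal sketch of the loading pattern this change enables (illustrative only; `_lazy_load` and `_materialize_tensors` are the private helpers added in the new file below, and the checkpoint file name is made up):

    import torch
    import torch.nn as nn
    from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors

    model = nn.Linear(2, 2)
    optimizer = torch.optim.Adam(model.parameters())
    torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, "full.pt")

    # Tensors come back as _NotYetLoadedTensor placeholders; no storage bytes are read yet
    checkpoint = _lazy_load("full.pt")

    # nn.Module.load_state_dict copies parameter by parameter, so tensors get
    # materialized one at a time instead of all at once in CPU memory
    model.load_state_dict(checkpoint["model"])

    # Optimizer.load_state_dict deepcopies its input, which lazy tensors cannot
    # survive, so the remaining entries must be materialized into real tensors first
    optimizer.load_state_dict(_materialize_tensors(checkpoint["optimizer"]))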

src/lightning/fabric/utilities/load.py

Lines changed: 205 additions & 0 deletions

@@ -0,0 +1,205 @@
+# Copyright 2023 MathInf GmbH
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use the files from this repository except in compliance
+# with the License reproduced below (also at
+# http://www.apache.org/licenses/LICENSE-2.0).
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import pickle
+import warnings
+from functools import partial
+from io import BytesIO
+from typing import Any, Callable, Dict, IO, Optional, OrderedDict, Sequence, TYPE_CHECKING, Union
+
+import torch
+from lightning_utilities.core.apply_func import apply_to_collection
+from torch import Tensor
+from torch._C import _TensorMeta
+from torch.nn import Parameter
+
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
+from lightning.fabric.utilities.types import _PATH
+
+if TYPE_CHECKING:
+    from torch.storage import TypedStorage
+
+
+# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann
+class _NotYetLoadedTensor:
+    def __init__(
+        self,
+        metatensor: Tensor,
+        archiveinfo: "_LazyLoadingUnpickler",
+        storageinfo: tuple,
+        rebuild_args: tuple,
+    ) -> None:
+        self.metatensor = metatensor
+        self.archiveinfo = archiveinfo
+        self.storageinfo = storageinfo
+        self.rebuild_args = rebuild_args
+
+    @classmethod
+    def rebuild_from_type_v2(
+        cls,
+        func: Callable,
+        new_type: _TensorMeta,
+        args: tuple,
+        state: dict,
+        *,
+        archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
+    ) -> Any:
+        ret = func(*args)
+        if isinstance(ret, _NotYetLoadedTensor):
+            old_lt = ret._load_tensor
+
+            def _load_tensor() -> Any:
+                t = old_lt()
+                return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state)
+
+            ret._load_tensor = _load_tensor  # type: ignore[method-assign]
+            return ret
+        return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)
+
+    @classmethod
+    def rebuild_parameter(
+        cls,
+        data: Any,
+        requires_grad: bool,
+        backward_hooks: OrderedDict,
+        *,
+        archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
+    ) -> Union[Tensor, "_NotYetLoadedTensor"]:
+        if isinstance(data, _NotYetLoadedTensor):
+            old_lt = data._load_tensor
+
+            def _load_tensor() -> Parameter:
+                t = old_lt()
+                return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)
+
+            data._load_tensor = _load_tensor  # type: ignore[method-assign]
+            return data
+        return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)
+
+    @classmethod
+    def rebuild_tensor_v2(
+        cls,
+        storage: "TypedStorage",
+        storage_offset: int,
+        size: tuple,
+        stride: tuple,
+        requires_grad: bool,
+        backward_hooks: OrderedDict,
+        metadata: Optional[Any] = None,
+        *,
+        archiveinfo: "_LazyLoadingUnpickler",
+    ) -> "_NotYetLoadedTensor":
+        rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata)
+        metatensor = torch._utils._rebuild_tensor_v2(
+            storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata
+        )
+        storageinfo = storage.archiveinfo
+        return _NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)
+
+    def _load_tensor(self) -> Tensor:
+        from torch.storage import TypedStorage, UntypedStorage
+
+        name, storage_cls, fn, device, size = self.storageinfo
+        dtype = self.metatensor.dtype
+
+        storage = self.archiveinfo.file_reader.get_storage_from_record(
+            f"data/{fn}", size * torch._utils._element_size(dtype), UntypedStorage
+        )
+        uts = storage._typed_storage()._untyped_storage
+
+        with warnings.catch_warnings():
+            # The TypedStorage APIs have heavy deprecations in torch, suppress all these warnings for now
+            warnings.simplefilter("ignore")
+            storage = TypedStorage(wrap_storage=uts, dtype=dtype, _internal=True)
+        return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)
+
+    @classmethod
+    def __torch_function__(
+        cls,
+        func: Callable,
+        types: Sequence,
+        args: Sequence[Any] = (),
+        kwargs: Optional[Dict] = None,
+    ) -> Any:
+        kwargs = kwargs or {}
+        loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args]
+        return func(*loaded_args, **kwargs)
+
+    def __getattr__(self, name: str) -> Any:
+        # These properties don't require materialization and can be accessed through the meta tensor directly
+        if name in {
+            "dtype",
+            "grad",
+            "grad_fn",
+            "layout",
+            "names",
+            "ndim",
+            "output_nr",
+            "requires_grad",
+            "retains_grad",
+            "size",
+            "shape",
+            "volatile",
+        }:
+            return getattr(self.metatensor, name)
+
+        # Materialization with contiguous is needed for quantization (see lit-gpt)
+        if name in {"contiguous"}:
+            return getattr(self._load_tensor(), name)
+
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}({repr(self.metatensor)})"
+
+
+# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann
+class _LazyLoadingUnpickler(pickle.Unpickler):
+    def __init__(self, file: IO, file_reader: torch.PyTorchFileReader) -> None:
+        super().__init__(file)
+        self.file_reader = file_reader
+
+    def find_class(self, module: str, name: str) -> Any:
+        if module == "torch._utils" and name == "_rebuild_tensor_v2":
+            return partial(_NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self)
+        if module == "torch._tensor" and name == "_rebuild_from_type_v2":
+            return partial(_NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self)
+        if module == "torch._utils" and name == "_rebuild_parameter":
+            return partial(_NotYetLoadedTensor.rebuild_parameter, archiveinfo=self)
+        return super().find_class(module, name)
+
+    def persistent_load(self, pid: tuple) -> "TypedStorage":
+        from torch.storage import TypedStorage
+
+        name, cls, fn, device, size = pid
+        with warnings.catch_warnings():
+            # The TypedStorage APIs have heavy deprecations in torch, suppress all these warnings for now
+            warnings.simplefilter("ignore")
+            storage = TypedStorage(dtype=cls().dtype, device="meta")
+        storage.archiveinfo = pid
+        return storage
+
+
+def _lazy_load(filename: _PATH) -> Any:
+    if not _TORCH_GREATER_EQUAL_2_0:
+        raise NotImplementedError("Lazy-loading is only supported with PyTorch >= 2.0.")
+    file_reader = torch.PyTorchFileReader(str(filename))
+    with BytesIO(file_reader.get_record("data.pkl")) as pkl:
+        mup = _LazyLoadingUnpickler(pkl, file_reader)
+        return mup.load()
+
+
+def _materialize_tensors(collection: Any) -> Any:
+    def _load_tensor(t: _NotYetLoadedTensor) -> Tensor:
+        return t._load_tensor()
+
+    return apply_to_collection(collection, dtype=_NotYetLoadedTensor, function=_load_tensor)
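
A quick illustration of the laziness contract above (a sketch; the checkpoint path is hypothetical and assumes a state dict with a "weight" key, as in nn.Linear): metadata access is served by the meta tensor via `__getattr__`, while torch-level operations route through `__torch_function__` and materialize the tensor from disk on first use.

    ckpt = _lazy_load("model.pt")
    weight = ckpt["weight"]            # _NotYetLoadedTensor; no storage bytes read yet
    print(weight.shape, weight.dtype)  # served by the meta tensor, still lazy
    total = torch.sum(weight)          # torch functions dispatch via __torch_function__,
                                       # which calls _load_tensor() and reads the storage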

tests/parity_fabric/test_parity_ddp.py

Lines changed: 1 addition & 1 deletion
@@ -161,6 +161,6 @@ def run_parity_test(accelerator: str = "cpu", devices: int = 2, tolerance: float


 if __name__ == "__main__":
-    from jsonargparse.cli import CLI
+    from jsonargparse import CLI

     CLI(run_parity_test)
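
As an aside, `jsonargparse.CLI` builds a command-line interface from a function's signature, which is how the parity script is invoked above; a minimal sketch (the `greet` function and script name are made up):

    from jsonargparse import CLI

    def greet(name: str = "world", times: int = 1) -> None:
        # Each parameter becomes a --flag with its annotated type and default
        for _ in range(times):
            print(f"hello {name}")

    if __name__ == "__main__":
        CLI(greet)  # e.g. `python greet.py --name Ada --times 2`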
tests/tests_fabric/utilities/test_load.py

Lines changed: 100 additions & 0 deletions

@@ -0,0 +1,100 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn as nn
+
+from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors, _NotYetLoadedTensor
+from tests_fabric.helpers.runif import RunIf
+
+
+@RunIf(min_torch="2.0.0")
+def test_lazy_load_module(tmp_path):
+    model0 = nn.Linear(2, 2)
+    torch.save(model0.state_dict(), tmp_path / "model.pt")
+
+    model1 = nn.Linear(2, 2)
+    checkpoint = _lazy_load(tmp_path / "model.pt")
+    model1.load_state_dict(checkpoint)
+
+    assert isinstance(checkpoint["weight"], _NotYetLoadedTensor)
+    assert type(model0.weight.data) == torch.Tensor
+    assert torch.equal(model0.weight, model1.weight)
+    assert torch.equal(model0.bias, model1.bias)
+
+
+class ATensor(torch.Tensor):
+    pass
+
+
+@RunIf(min_torch="2.0.0")
+def test_lazy_load_tensor(tmp_path):
+    """Test that lazy load can handle different classes of tensors."""
+    expected = {
+        "tensor": torch.rand(2),
+        "parameter": nn.Parameter(torch.rand(3)),
+        "subclass": torch.Tensor._make_subclass(ATensor, torch.rand(4)),
+    }
+    torch.save(expected, tmp_path / "data.pt")
+
+    loaded = _lazy_load(tmp_path / "data.pt")
+    for t0, t1 in zip(expected.values(), loaded.values()):
+        assert isinstance(t1, _NotYetLoadedTensor)
+        t1_materialized = _materialize_tensors(t1)
+        assert type(t0) == type(t1_materialized)
+        assert torch.equal(t0, t1_materialized)
+
+
+@RunIf(min_torch="2.0.0")
+def test_lazy_load_mixed_state(tmp_path):
+    model0 = nn.Linear(2, 2)
+    optim0 = torch.optim.Adam(model0.parameters())
+    checkpoint = {
+        "int": 1,
+        "dict": {"a": 1, "b": 2},
+        "list": [1, 2, 3],
+        "pickled_model": model0,
+        "model": model0.state_dict(),
+        "optimizer": optim0.state_dict(),
+    }
+    torch.save(checkpoint, tmp_path / "checkpoint.pt")
+
+    model1 = nn.Linear(2, 2)
+    optim1 = torch.optim.Adam(model1.parameters())
+    loaded_checkpoint = _lazy_load(tmp_path / "checkpoint.pt")
+    model1.load_state_dict(loaded_checkpoint["model"])
+    optim1.load_state_dict(loaded_checkpoint["optimizer"])
+
+
+@RunIf(min_torch="2.0.0")
+def test_materialize_tensors(tmp_path):
+    # Single tensor
+    tensor = torch.tensor([1, 2])
+    torch.save(tensor, tmp_path / "tensor.pt")
+    loaded = _lazy_load(tmp_path / "tensor.pt")
+    materialized = _materialize_tensors(loaded)
+    assert torch.equal(materialized, tensor)
+    assert type(tensor) == type(materialized)
+
+    # Collection of tensors
+    collection = {
+        "tensor": torch.tensor([1, 2]),
+        "nested": {"int": 1, "list": [torch.tensor([3.0]), torch.tensor([4])]},
+    }
+    torch.save(collection, tmp_path / "collection.pt")
+    loaded = _lazy_load(tmp_path / "collection.pt")
+    materialized = _materialize_tensors(loaded)
+    assert torch.equal(materialized["tensor"], collection["tensor"])
+    assert torch.equal(materialized["nested"]["list"][0], collection["nested"]["list"][0])
+    assert torch.equal(materialized["nested"]["list"][1], collection["nested"]["list"][1])
+    assert materialized["nested"]["int"] == 1