Add lazy checkpoint loading for FSDP full-state checkpoints #18150
Merged
Changes from all commits (28 commits):
38cb445  wip (awaelchli)
47ffc6e  simplify (awaelchli)
622bb2e  clean up (awaelchli)
b7a5a26  add utility to fsdp (awaelchli)
4acdc29  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
3a4934b  fix (awaelchli)
5cfeebb  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
80e4fae  make function (awaelchli)
37276b0  load (awaelchli)
82397a7  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
8781717  note (awaelchli)
7313e0d  materialize (awaelchli)
436dc73  Merge remote-tracking branch 'origin/fabric/lazy-load' into fabric/la… (awaelchli)
eb409b1  update (awaelchli)
10a59d9  recursive materialize (awaelchli)
464f131  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
f19f09f  fix (awaelchli)
df34359  clean up (awaelchli)
0869294  update (awaelchli)
f9e1a84  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
1965283  remove unused import (awaelchli)
9cc4e44  pytorch 2.0 only (awaelchli)
ae74fbb  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
a592445  mypy (awaelchli)
7b5c501  better attribute error message (awaelchli)
b23f728  add comment about warnings (awaelchli)
b83f4a0  Add note about original author (awaelchli)
bf78b23  avoid jsonargparse deprecation message (awaelchli)
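In short, the PR adds a `_lazy_load` utility that unpickles a `torch.save` checkpoint into lightweight placeholder tensors and defers reading tensor data from the archive until first use. A rough end-to-end sketch of the behavior, mirroring the tests added below:

```python
import torch
import torch.nn as nn

from lightning.fabric.utilities.load import _lazy_load, _NotYetLoadedTensor

model0 = nn.Linear(2, 2)
torch.save(model0.state_dict(), "model.pt")

checkpoint = _lazy_load("model.pt")  # reads only pickle metadata, no tensor bytes
assert isinstance(checkpoint["weight"], _NotYetLoadedTensor)

model1 = nn.Linear(2, 2)
model1.load_state_dict(checkpoint)  # each entry materializes as it is copied
assert torch.equal(model0.weight, model1.weight)
```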
New file: src/lightning/fabric/utilities/load.py (+205 lines)
# Copyright 2023 MathInf GmbH
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use these files from this repository except in compliance
# with the License reproduced below (also at
# http://www.apache.org/licenses/LICENSE-2.0).
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import warnings
from functools import partial
from io import BytesIO
from typing import Any, Callable, Dict, IO, Optional, OrderedDict, Sequence, TYPE_CHECKING, Union

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch._C import _TensorMeta
from torch.nn import Parameter

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import _PATH

if TYPE_CHECKING:
    from torch.storage import TypedStorage

# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann
class _NotYetLoadedTensor:
    def __init__(
        self,
        metatensor: Tensor,
        archiveinfo: "_LazyLoadingUnpickler",
        storageinfo: tuple,
        rebuild_args: tuple,
    ) -> None:
        self.metatensor = metatensor
        self.archiveinfo = archiveinfo
        self.storageinfo = storageinfo
        self.rebuild_args = rebuild_args

    @classmethod
    def rebuild_from_type_v2(
        cls,
        func: Callable,
        new_type: _TensorMeta,
        args: tuple,
        state: dict,
        *,
        archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
    ) -> Any:
        ret = func(*args)
        if isinstance(ret, _NotYetLoadedTensor):
            old_lt = ret._load_tensor

            def _load_tensor() -> Any:
                t = old_lt()
                return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state)

            ret._load_tensor = _load_tensor  # type: ignore[method-assign]
            return ret
        return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)

    @classmethod
    def rebuild_parameter(
        cls,
        data: Any,
        requires_grad: bool,
        backward_hooks: OrderedDict,
        *,
        archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
    ) -> Union[Tensor, "_NotYetLoadedTensor"]:
        if isinstance(data, _NotYetLoadedTensor):
            old_lt = data._load_tensor

            def _load_tensor() -> Parameter:
                t = old_lt()
                return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)

            data._load_tensor = _load_tensor  # type: ignore[method-assign]
            return data
        return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)

    @classmethod
    def rebuild_tensor_v2(
        cls,
        storage: "TypedStorage",
        storage_offset: int,
        size: tuple,
        stride: tuple,
        requires_grad: bool,
        backward_hooks: OrderedDict,
        metadata: Optional[Any] = None,
        *,
        archiveinfo: "_LazyLoadingUnpickler",
    ) -> "_NotYetLoadedTensor":
        rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata)
        metatensor = torch._utils._rebuild_tensor_v2(
            storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata
        )
        storageinfo = storage.archiveinfo
        return _NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)

    def _load_tensor(self) -> Tensor:
        from torch.storage import TypedStorage, UntypedStorage

        name, storage_cls, fn, device, size = self.storageinfo
        dtype = self.metatensor.dtype

        storage = self.archiveinfo.file_reader.get_storage_from_record(
            f"data/{fn}", size * torch._utils._element_size(dtype), UntypedStorage
        )
        uts = storage._typed_storage()._untyped_storage

        with warnings.catch_warnings():
            # The TypedStorage APIs have heavy deprecations in torch, suppress all these warnings for now
            warnings.simplefilter("ignore")
            storage = TypedStorage(wrap_storage=uts, dtype=dtype, _internal=True)
        return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)

    @classmethod
    def __torch_function__(
        cls,
        func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ) -> Any:
        kwargs = kwargs or {}
        loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args]
        return func(*loaded_args, **kwargs)

    def __getattr__(self, name: str) -> Any:
        # These properties don't require materialization and can be accessed through the meta tensor directly
        if name in {
            "dtype",
            "grad",
            "grad_fn",
            "layout",
            "names",
            "ndim",
            "output_nr",
            "requires_grad",
            "retains_grad",
            "size",
            "shape",
            "volatile",
        }:
            return getattr(self.metatensor, name)

        # Materialization with contiguous is needed for quantization (see lit-gpt)
        if name in {"contiguous"}:
            return getattr(self._load_tensor(), name)

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({repr(self.metatensor)})"

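A quick sketch of how the placeholder above behaves (using only names from this diff): whitelisted attributes are answered by the meta tensor without touching the file, `contiguous()` materializes explicitly, and any torch op routed through `__torch_function__` materializes its arguments first.

```python
import torch

from lightning.fabric.utilities.load import _lazy_load

torch.save({"w": torch.rand(4, 4)}, "w.pt")
lazy = _lazy_load("w.pt")["w"]

print(lazy.shape, lazy.dtype, lazy.requires_grad)  # served by the meta tensor, no disk read

t = lazy.contiguous()  # whitelisted materializer: the storage is read here
assert torch.equal(t, torch.load("w.pt")["w"])
```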
# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann
class _LazyLoadingUnpickler(pickle.Unpickler):
    def __init__(self, file: IO, file_reader: torch.PyTorchFileReader) -> None:
        super().__init__(file)
        self.file_reader = file_reader

    def find_class(self, module: str, name: str) -> Any:
        if module == "torch._utils" and name == "_rebuild_tensor_v2":
            return partial(_NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self)
        if module == "torch._tensor" and name == "_rebuild_from_type_v2":
            return partial(_NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self)
        if module == "torch._utils" and name == "_rebuild_parameter":
            return partial(_NotYetLoadedTensor.rebuild_parameter, archiveinfo=self)
        return super().find_class(module, name)

    def persistent_load(self, pid: tuple) -> "TypedStorage":
        from torch.storage import TypedStorage

        name, cls, fn, device, size = pid
        with warnings.catch_warnings():
            # The TypedStorage APIs have heavy deprecations in torch, suppress all these warnings for now
            warnings.simplefilter("ignore")
            storage = TypedStorage(dtype=cls().dtype, device="meta")
        storage.archiveinfo = pid
        return storage

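For readers unfamiliar with the two `pickle` hooks being overridden here: `find_class` lets the unpickler substitute the callables named in the stream (above, the `torch._utils` rebuild helpers), and `persistent_load` resolves objects that were stored out-of-band by ID (above, tensor storages, answered with empty meta-device storages). A stdlib-only sketch of the mechanism, with hypothetical demo classes not from this PR:

```python
import io
import pickle
from collections import OrderedDict

class DemoPickler(pickle.Pickler):
    def persistent_id(self, obj):
        # Keep only an ID in the stream for selected objects, the same
        # mechanism torch.save uses for tensor storages.
        return ("blob", obj.decode()) if isinstance(obj, bytes) else None

class DemoUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        print(f"find_class requested: {module}.{name}")  # e.g. collections.OrderedDict
        return super().find_class(module, name)

    def persistent_load(self, pid):
        # The PR returns a meta-device TypedStorage here, so no data is read.
        tag, value = pid
        return f"<resolved {value}>"

buf = io.BytesIO()
DemoPickler(buf).dump(OrderedDict(note=b"hello"))
buf.seek(0)
print(DemoUnpickler(buf).load())  # OrderedDict([('note', '<resolved hello>')])
```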
def _lazy_load(filename: _PATH) -> Any:
    if not _TORCH_GREATER_EQUAL_2_0:
        raise NotImplementedError("Lazy-loading is only supported with PyTorch >= 2.0.")
    file_reader = torch.PyTorchFileReader(str(filename))
    with BytesIO(file_reader.get_record("data.pkl")) as pkl:
        mup = _LazyLoadingUnpickler(pkl, file_reader)
        return mup.load()


def _materialize_tensors(collection: Any) -> Any:
    def _load_tensor(t: _NotYetLoadedTensor) -> Tensor:
        return t._load_tensor()

    return apply_to_collection(collection, dtype=_NotYetLoadedTensor, function=_load_tensor)
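`_materialize_tensors` walks any nested structure via `apply_to_collection` and swaps each placeholder for a real tensor, leaving non-tensor leaves untouched. A small sketch:

```python
import torch

from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors

torch.save({"step": 10, "weights": [torch.rand(2), torch.rand(3)]}, "ckpt.pt")

lazy = _lazy_load("ckpt.pt")
full = _materialize_tensors(lazy)  # placeholders replaced by real tensors

assert full["step"] == 10  # non-tensor values pass through unchanged
assert isinstance(full["weights"][0], torch.Tensor)
```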
New file: tests/tests_fabric/utilities/test_load.py (+100 lines)
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn

from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors, _NotYetLoadedTensor
from tests_fabric.helpers.runif import RunIf


@RunIf(min_torch="2.0.0")
def test_lazy_load_module(tmp_path):
    model0 = nn.Linear(2, 2)
    torch.save(model0.state_dict(), tmp_path / "model.pt")

    model1 = nn.Linear(2, 2)
    checkpoint = _lazy_load(tmp_path / "model.pt")
    model1.load_state_dict(checkpoint)

    assert isinstance(checkpoint["weight"], _NotYetLoadedTensor)
    assert type(model0.weight.data) == torch.Tensor
    assert torch.equal(model0.weight, model1.weight)
    assert torch.equal(model0.bias, model1.bias)


class ATensor(torch.Tensor):
    pass


@RunIf(min_torch="2.0.0")
def test_lazy_load_tensor(tmp_path):
    """Test that lazy load can handle different classes of tensors."""
    expected = {
        "tensor": torch.rand(2),
        "parameter": nn.Parameter(torch.rand(3)),
        "subclass": torch.Tensor._make_subclass(ATensor, torch.rand(4)),
    }
    torch.save(expected, tmp_path / "data.pt")

    loaded = _lazy_load(tmp_path / "data.pt")
    for t0, t1 in zip(expected.values(), loaded.values()):
        assert isinstance(t1, _NotYetLoadedTensor)
        t1_materialized = _materialize_tensors(t1)
        assert type(t0) == type(t1_materialized)
        assert torch.equal(t0, t1_materialized)


@RunIf(min_torch="2.0.0")
def test_lazy_load_mixed_state(tmp_path):
    model0 = nn.Linear(2, 2)
    optim0 = torch.optim.Adam(model0.parameters())
    checkpoint = {
        "int": 1,
        "dict": {"a": 1, "b": 2},
        "list": [1, 2, 3],
        "pickled_model": model0,
        "model": model0.state_dict(),
        "optimizer": optim0.state_dict(),
    }
    torch.save(checkpoint, tmp_path / "checkpoint.pt")

    model1 = nn.Linear(2, 2)
    optim1 = torch.optim.Adam(model0.parameters())
    loaded_checkpoint = _lazy_load(tmp_path / "checkpoint.pt")
    model1.load_state_dict(loaded_checkpoint["model"])
    optim1.load_state_dict(loaded_checkpoint["optimizer"])


@RunIf(min_torch="2.0.0")
def test_materialize_tensors(tmp_path):
    # Single tensor
    tensor = torch.tensor([1, 2])
    torch.save(tensor, tmp_path / "tensor.pt")
    loaded = _lazy_load(tmp_path / "tensor.pt")
    materialized = _materialize_tensors(loaded)
    assert torch.equal(materialized, tensor)
    assert type(tensor) == type(materialized)

    # Collection of tensors
    collection = {
        "tensor": torch.tensor([1, 2]),
        "nested": {"int": 1, "list": [torch.tensor([3.0]), torch.tensor([4])]},
    }
    torch.save(collection, tmp_path / "collection.pt")
    loaded = _lazy_load(tmp_path / "collection.pt")
    materialized = _materialize_tensors(loaded)
    assert torch.equal(materialized["tensor"], collection["tensor"])
    assert torch.equal(materialized["nested"]["list"][0], collection["nested"]["list"][0])
    assert torch.equal(materialized["nested"]["list"][1], collection["nested"]["list"][1])
    assert materialized["nested"]["int"] == 1
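One practical consequence, and presumably the point for FSDP full-state checkpoints: because every entry materializes independently, a large checkpoint can be streamed into a model one parameter at a time instead of holding a second complete copy in memory. A hedged sketch of that pattern (illustrative, not code from this PR):

```python
import torch
import torch.nn as nn

from lightning.fabric.utilities.load import _lazy_load

model = nn.Linear(8, 8)
torch.save(model.state_dict(), "full_state.pt")

checkpoint = _lazy_load("full_state.pt")
with torch.no_grad():
    for name, param in model.named_parameters():
        # Only this entry's bytes are read from the archive on copy_,
        # which dispatches through _NotYetLoadedTensor.__torch_function__.
        param.copy_(checkpoint[name])
```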