Merged
28 commits
38cb445
wip
awaelchli Jul 23, 2023
47ffc6e
simplify
awaelchli Jul 24, 2023
622bb2e
clean up
awaelchli Jul 24, 2023
b7a5a26
add utility to fsdp
awaelchli Jul 24, 2023
4acdc29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 24, 2023
3a4934b
fix
awaelchli Jul 24, 2023
5cfeebb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 24, 2023
80e4fae
make function
awaelchli Jul 25, 2023
37276b0
load
awaelchli Jul 25, 2023
82397a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2023
8781717
note
awaelchli Jul 25, 2023
7313e0d
materialize
awaelchli Jul 25, 2023
436dc73
Merge remote-tracking branch 'origin/fabric/lazy-load' into fabric/la…
awaelchli Jul 25, 2023
eb409b1
update
awaelchli Jul 25, 2023
10a59d9
recursive materialize
awaelchli Jul 25, 2023
464f131
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2023
f19f09f
fix
awaelchli Jul 25, 2023
df34359
clean up
awaelchli Jul 25, 2023
0869294
update
awaelchli Jul 25, 2023
f9e1a84
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2023
1965283
remove unused import
awaelchli Jul 26, 2023
9cc4e44
pytorch 2.0 only
awaelchli Jul 26, 2023
ae74fbb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 26, 2023
a592445
mypy
awaelchli Jul 26, 2023
7b5c501
better attribute error message
awaelchli Jul 26, 2023
b23f728
add comment about warnings
awaelchli Jul 26, 2023
b83f4a0
Add note about original author
awaelchli Jul 26, 2023
bf78b23
avoid jsonargparse deprecation message
awaelchli Jul 26, 2023
13 changes: 9 additions & 4 deletions src/lightning/fabric/strategies/fsdp.py
@@ -67,6 +67,7 @@
_TORCH_GREATER_EQUAL_2_1,
)
from lightning.fabric.utilities.init import _EmptyInit
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import _PATH
@@ -573,12 +574,17 @@ def load_checkpoint(
return metadata

if _is_full_checkpoint(path):
checkpoint = torch.load(path, map_location="cpu")
checkpoint = _lazy_load(path) if _TORCH_GREATER_EQUAL_2_0 else torch.load(path, map_location="cpu")
_load_raw_module_state(checkpoint.pop(module_key), module=module, strict=strict)

if isinstance(state, Module):
return {}

if _TORCH_GREATER_EQUAL_2_0:
# Materialize lazy tensors if there are any left in the checkpoint
# The `torch.optim.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues
checkpoint = _materialize_tensors(checkpoint)

# Load optimizer states
for optim_key, optim in optimizers.items():
# rank0_only should be false because we need to load the optimizer state on all ranks
@@ -821,9 +827,8 @@ def _load_raw_module_state(path_or_ckpt: Union[Path, Dict[str, Any]], module: Mo

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# This is inefficient, as multiple copies of the checkpoint are held in CPU memory at once.
# There is currently no other way because `summon_full_params` does not support write-back from rank 0 only.
state_dict = torch.load(path_or_ckpt, map_location="cpu") if not isinstance(path_or_ckpt, dict) else path_or_ckpt
# Use `lazy_load` instead of `torch.load` here to avoid storing a copy of the full checkpoint per rank
state_dict = _lazy_load(path_or_ckpt) if isinstance(path_or_ckpt, Path) else path_or_ckpt
with FSDP.summon_full_params(module, writeback=True, rank0_only=False):
module.load_state_dict(state_dict, strict=strict)

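For context, a minimal sketch of the loading pattern the changed `load_checkpoint` follows, reduced to a plain `nn.Linear` and Adam without FSDP. The toy model, the optimizer, and the "full.ckpt" file name are illustrative, not part of the diff:

import torch
import torch.nn as nn
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors

# Illustrative setup: save a full checkpoint holding module and optimizer state.
model = nn.Linear(2, 2)
optimizer = torch.optim.Adam(model.parameters())
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, "full.ckpt")

# Lazy-load: tensor data stays in the archive until it is actually needed.
checkpoint = _lazy_load("full.ckpt")

# Module weights load directly; `load_state_dict` copies parameter by parameter,
# so tensors are materialized one at a time rather than the whole checkpoint at once.
model.load_state_dict(checkpoint.pop("model"))

# Optimizer state is materialized first, because `Optimizer.load_state_dict`
# deep-copies its input and lazy tensors do not survive that.
checkpoint = _materialize_tensors(checkpoint)
optimizer.load_state_dict(checkpoint["optimizer"])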
205 changes: 205 additions & 0 deletions src/lightning/fabric/utilities/load.py
@@ -0,0 +1,205 @@
# Copyright 2023 MathInf GmbH
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use the files from this repository except in compliance
# with the License reproduced below (also at
# http://www.apache.org/licenses/LICENSE-2.0).
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import warnings
from functools import partial
from io import BytesIO
from typing import Any, Callable, Dict, IO, Optional, OrderedDict, Sequence, TYPE_CHECKING, Union

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch._C import _TensorMeta
from torch.nn import Parameter

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0
from lightning.fabric.utilities.types import _PATH

if TYPE_CHECKING:
from torch.storage import TypedStorage


# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann
class _NotYetLoadedTensor:
def __init__(
self,
metatensor: Tensor,
archiveinfo: "_LazyLoadingUnpickler",
storageinfo: tuple,
rebuild_args: tuple,
) -> None:
self.metatensor = metatensor
self.archiveinfo = archiveinfo
self.storageinfo = storageinfo
self.rebuild_args = rebuild_args

@classmethod
def rebuild_from_type_v2(
cls,
func: Callable,
new_type: _TensorMeta,
args: tuple,
state: dict,
*,
archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
) -> Any:
ret = func(*args)
if isinstance(ret, _NotYetLoadedTensor):
old_lt = ret._load_tensor

def _load_tensor() -> Any:
t = old_lt()
return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state)

ret._load_tensor = _load_tensor # type: ignore[method-assign]
return ret
return torch._tensor._rebuild_from_type_v2(func, new_type, args, state)

@classmethod
def rebuild_parameter(
cls,
data: Any,
requires_grad: bool,
backward_hooks: OrderedDict,
*,
archiveinfo: Optional["_LazyLoadingUnpickler"] = None,
) -> Union[Tensor, "_NotYetLoadedTensor"]:
if isinstance(data, _NotYetLoadedTensor):
old_lt = data._load_tensor

def _load_tensor() -> Parameter:
t = old_lt()
return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks)

data._load_tensor = _load_tensor # type: ignore[method-assign]
return data
return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks)

@classmethod
def rebuild_tensor_v2(
cls,
storage: "TypedStorage",
storage_offset: int,
size: tuple,
stride: tuple,
requires_grad: bool,
backward_hooks: OrderedDict,
metadata: Optional[Any] = None,
*,
archiveinfo: "_LazyLoadingUnpickler",
) -> "_NotYetLoadedTensor":
rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata)
metatensor = torch._utils._rebuild_tensor_v2(
storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata
)
storageinfo = storage.archiveinfo
return _NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args)

def _load_tensor(self) -> Tensor:
from torch.storage import TypedStorage, UntypedStorage

name, storage_cls, fn, device, size = self.storageinfo
dtype = self.metatensor.dtype

storage = self.archiveinfo.file_reader.get_storage_from_record(
f"data/{fn}", size * torch._utils._element_size(dtype), UntypedStorage
)
uts = storage._typed_storage()._untyped_storage

with warnings.catch_warnings():
# The TypedStorage APIs have heavy deprecations in torch, suppress all these warnings for now
warnings.simplefilter("ignore")
storage = TypedStorage(wrap_storage=uts, dtype=dtype, _internal=True)
return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args)

@classmethod
def __torch_function__(
cls,
func: Callable,
types: Sequence,
args: Sequence[Any] = (),
kwargs: Optional[Dict] = None,
) -> Any:
kwargs = kwargs or {}
loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args]
return func(*loaded_args, **kwargs)

def __getattr__(self, name: str) -> Any:
# These properties don't require materialization and can be accessed through the meta tensor directly
if name in {
"dtype",
"grad",
"grad_fn",
"layout",
"names",
"ndim",
"output_nr",
"requires_grad",
"retains_grad",
"size",
"shape",
"volatile",
}:
return getattr(self.metatensor, name)

# Materialization with contiguous is needed for quantization (see lit-gpt)
if name in {"contiguous"}:
return getattr(self._load_tensor(), name)

raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

def __repr__(self) -> str:
return f"{self.__class__.__name__}({repr(self.metatensor)})"


# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann
class _LazyLoadingUnpickler(pickle.Unpickler):
def __init__(self, file: IO, file_reader: torch.PyTorchFileReader) -> None:
super().__init__(file)
self.file_reader = file_reader

def find_class(self, module: str, name: str) -> Any:
if module == "torch._utils" and name == "_rebuild_tensor_v2":
return partial(_NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self)
if module == "torch._tensor" and name == "_rebuild_from_type_v2":
return partial(_NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self)
if module == "torch._utils" and name == "_rebuild_parameter":
return partial(_NotYetLoadedTensor.rebuild_parameter, archiveinfo=self)
return super().find_class(module, name)

def persistent_load(self, pid: tuple) -> "TypedStorage":
from torch.storage import TypedStorage

name, cls, fn, device, size = pid
with warnings.catch_warnings():
# The TypedStorage APIs have heavy deprecations in torch, suppress all these warnings for now
warnings.simplefilter("ignore")
storage = TypedStorage(dtype=cls().dtype, device="meta")
storage.archiveinfo = pid
return storage


def _lazy_load(filename: _PATH) -> Any:
if not _TORCH_GREATER_EQUAL_2_0:
raise NotImplementedError("Lazy-loading is only supported with PyTorch >= 2.0.")
file_reader = torch.PyTorchFileReader(str(filename))
with BytesIO(file_reader.get_record("data.pkl")) as pkl:
mup = _LazyLoadingUnpickler(pkl, file_reader)
return mup.load()


def _materialize_tensors(collection: Any) -> Any:
def _load_tensor(t: _NotYetLoadedTensor) -> Tensor:
return t._load_tensor()

return apply_to_collection(collection, dtype=_NotYetLoadedTensor, function=_load_tensor)
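The `__getattr__`/`__torch_function__` pair above is what lets a `_NotYetLoadedTensor` answer metadata queries from the meta tensor alone and defer the storage read until the tensor takes part in a real operation. A small sketch of that behaviour, assuming PyTorch >= 2.0; the "weights.pt" file is illustrative:

import torch
from lightning.fabric.utilities.load import _lazy_load, _NotYetLoadedTensor

torch.save({"weight": torch.rand(4, 4)}, "weights.pt")  # illustrative file
lazy = _lazy_load("weights.pt")["weight"]
assert isinstance(lazy, _NotYetLoadedTensor)

# Served by the meta tensor: no data is read from the archive for these.
print(lazy.shape, lazy.dtype, lazy.requires_grad)

# Any real torch call goes through `__torch_function__`, which materializes
# the lazy arguments first, so the storage is only read from the file here.
total = torch.sum(lazy)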
2 changes: 1 addition & 1 deletion tests/parity_fabric/test_parity_ddp.py
@@ -161,6 +161,6 @@ def run_parity_test(accelerator: str = "cpu", devices: int = 2, tolerance: float


if __name__ == "__main__":
from jsonargparse.cli import CLI
from jsonargparse import CLI

CLI(run_parity_test)
100 changes: 100 additions & 0 deletions tests/tests_fabric/utilities/test_load.py
@@ -0,0 +1,100 @@
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn

from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors, _NotYetLoadedTensor
from tests_fabric.helpers.runif import RunIf


@RunIf(min_torch="2.0.0")
def test_lazy_load_module(tmp_path):
model0 = nn.Linear(2, 2)
torch.save(model0.state_dict(), tmp_path / "model.pt")

model1 = nn.Linear(2, 2)
checkpoint = _lazy_load(tmp_path / "model.pt")
model1.load_state_dict(checkpoint)

assert isinstance(checkpoint["weight"], _NotYetLoadedTensor)
assert type(model0.weight.data) == torch.Tensor
assert torch.equal(model0.weight, model1.weight)
assert torch.equal(model0.bias, model1.bias)


class ATensor(torch.Tensor):
pass


@RunIf(min_torch="2.0.0")
def test_lazy_load_tensor(tmp_path):
"""Test that lazy load can handle different classes of tensors."""
expected = {
"tensor": torch.rand(2),
"parameter": nn.Parameter(torch.rand(3)),
"subclass": torch.Tensor._make_subclass(ATensor, torch.rand(4)),
}
torch.save(expected, tmp_path / "data.pt")

loaded = _lazy_load(tmp_path / "data.pt")
for t0, t1 in zip(expected.values(), loaded.values()):
assert isinstance(t1, _NotYetLoadedTensor)
t1_materialized = _materialize_tensors(t1)
assert type(t0) == type(t1_materialized)
assert torch.equal(t0, t1_materialized)


@RunIf(min_torch="2.0.0")
def test_lazy_load_mixed_state(tmp_path):
model0 = nn.Linear(2, 2)
optim0 = torch.optim.Adam(model0.parameters())
checkpoint = {
"int": 1,
"dict": {"a": 1, "b": 2},
"list": [1, 2, 3],
"pickled_model": model0,
"model": model0.state_dict(),
"optimizer": optim0.state_dict(),
}
torch.save(checkpoint, tmp_path / "checkpoint.pt")

model1 = nn.Linear(2, 2)
optim1 = torch.optim.Adam(model1.parameters())
loaded_checkpoint = _lazy_load(tmp_path / "checkpoint.pt")
model1.load_state_dict(loaded_checkpoint["model"])
optim1.load_state_dict(loaded_checkpoint["optimizer"])


@RunIf(min_torch="2.0.0")
def test_materialize_tensors(tmp_path):
# Single tensor
tensor = torch.tensor([1, 2])
torch.save(tensor, tmp_path / "tensor.pt")
loaded = _lazy_load(tmp_path / "tensor.pt")
materialized = _materialize_tensors(loaded)
assert torch.equal(materialized, tensor)
assert type(tensor) == type(materialized)

# Collection of tensors
collection = {
"tensor": torch.tensor([1, 2]),
"nested": {"int": 1, "list": [torch.tensor([3.0]), torch.tensor([4])]},
}
torch.save(collection, tmp_path / "collection.pt")
loaded = _lazy_load(tmp_path / "collection.pt")
materialized = _materialize_tensors(loaded)
assert torch.equal(materialized["tensor"], collection["tensor"])
assert torch.equal(materialized["nested"]["list"][0], collection["nested"]["list"][0])
assert torch.equal(materialized["nested"]["list"][1], collection["nested"]["list"][1])
assert materialized["nested"]["int"] == 1
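Because storages are fetched from the checkpoint archive one record at a time, lazy loading also makes it cheap to pull a single entry out of a much larger checkpoint. A hedged sketch of that use; the "mixed.pt" file and its contents are illustrative:

import torch
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors

# Illustrative checkpoint with one large and one small entry.
torch.save({"big": torch.zeros(4096, 4096), "small": torch.arange(8)}, "mixed.pt")

checkpoint = _lazy_load("mixed.pt")
# Only the storage behind "small" is read from disk; "big" is never materialized.
small = _materialize_tensors(checkpoint["small"])
print(small)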