From 38cb445c6db6efa7e4c9829ea9874c48383963bf Mon Sep 17 00:00:00 2001 From: awaelchli Date: Sun, 23 Jul 2023 14:07:31 +0200 Subject: [PATCH 01/27] wip --- src/lightning/fabric/utilities/load.py | 178 ++++++++++++++++++++++ tests/tests_fabric/utilities/test_load.py | 39 +++++ 2 files changed, 217 insertions(+) create mode 100644 src/lightning/fabric/utilities/load.py create mode 100644 tests/tests_fabric/utilities/test_load.py diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py new file mode 100644 index 0000000000000..ff56e6f28c6b9 --- /dev/null +++ b/src/lightning/fabric/utilities/load.py @@ -0,0 +1,178 @@ +# Copyright 2023 MathInf GmbH +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this files from this repository except in compliance +# with the License reproduced below (also at +# http://www.apache.org/licenses/LICENSE-2.0). +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +import pickle +import warnings +from io import BytesIO +from typing import IO, Any, Callable, Sequence, Optional, Dict + +import torch +import torch.utils._device +from torch import Tensor +from torch.storage import TypedStorage, UntypedStorage + +from lightning.fabric.utilities.types import _PATH + + +# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann +class _NotYetLoadedTensor: + def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args): + self.metatensor = metatensor + self.archiveinfo = archiveinfo + self.storageinfo = storageinfo + self.rebuild_args = rebuild_args + + @classmethod + def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None): + ret = func(*args) + if isinstance(ret, _NotYetLoadedTensor): + old_lt = ret._load_tensor + + def _load_tensor(): + t = old_lt() + return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state) + + ret._load_tensor = _load_tensor + return ret + return torch._tensor._rebuild_from_type_v2(func, new_type, args, state) + + @classmethod + def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, archiveinfo=None): + if isinstance(data, _NotYetLoadedTensor): + old_lt = data._load_tensor + + def _load_tensor(): + t = old_lt() + return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks) + + data._load_tensor = _load_tensor + return data + return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks) + + @classmethod + def rebuild_tensor_v2( + cls, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None, *, archiveinfo=None + ): + rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata) + metatensor = torch._utils._rebuild_tensor_v2( + storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata + ) + storageinfo = storage.archiveinfo + return _NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args) + + def _load_tensor(self) -> Tensor: + name, storage_cls, fn, device, size = self.storageinfo + dtype = self.metatensor.dtype + + uts = ( + self.archiveinfo.zipfile_context.zf.get_storage_from_record( + f"data/{fn}", size * torch._utils._element_size(dtype), UntypedStorage + ) + ._typed_storage() + 
._untyped_storage + ) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + storage = TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True) + return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args) + + @classmethod + def __torch_function__( + cls, + func: Callable, + types: Sequence, + args: Sequence[Any] = (), + kwargs: Optional[Dict] = None, + ) -> Any: + if kwargs is None: # TODO: simplify + kwargs = {} + loaded_args = [(a._load_tensor() if isinstance(a, _NotYetLoadedTensor) else a) for a in args] + return func(*loaded_args, **kwargs) + + def __getattr__(self, name: str) -> Any: + # properties + ## TODO: device, is_...?? + ## TODO: mH, mT, H, T, data, imag, real + ## name ??? + if name in { + "dtype", + "grad", + "grad_fn", + "layout", + "names", + "ndim", + "output_nr", + "requires_grad", + "retains_grad", + "shape", + "volatile", + }: + return getattr(self.metatensor, name) + if name in {"size"}: + return getattr(self.metatensor, name) + # materializing with contiguous is needed for quantization + if name in {"contiguous"}: + return getattr(self._load_tensor(), name) + + raise AttributeError(f"{type(self)} does not have {name}") + + def __repr__(self) -> str: + return f"_NotYetLoadedTensor({repr(self.metatensor)})" + + +class _LazyLoadingUnpickler(pickle.Unpickler): + def __init__(self, file: IO, zipfile_context) -> None: + super().__init__(file) + self.zipfile_context = zipfile_context # TODO: is this needed? + + def find_class(self, module: str, name: str) -> Any: + res = super().find_class(module, name) # TODO: move to bottom + if module == "torch._utils" and name == "_rebuild_tensor_v2": + return partial(_NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self) + if module == "torch._tensor" and name == "_rebuild_from_type_v2": + return partial(_NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self) + if module == "torch._utils" and name == "_rebuild_parameter": + return partial(_NotYetLoadedTensor.rebuild_parameter, archiveinfo=self) + return res + + def persistent_load(self, pid) -> TypedStorage: + name, cls, fn, device, size = pid + with warnings.catch_warnings(): + warnings.simplefilter("ignore") # TODO: needed? + storage = TypedStorage(dtype=cls().dtype, device="meta") + storage.archiveinfo = pid + return storage + + +# TODO: make functional? +class _LazyLoad: + def __init__(self, filename: _PATH) -> None: + self.zf = torch._C.PyTorchFileReader(str(filename)) + with BytesIO(self.zf.get_record("data.pkl")) as pkl: + mup = _LazyLoadingUnpickler(pkl, self) + self.sd = mup.load() + + def __enter__(self): + return self.sd + + def __exit__(self, exc_type, exc_val, exc_tb): + del self.zf # I don't think there is a way to force closing... + self.zf = None + + +def _lazy_load(filename: _PATH) -> Any: + zf = torch._C.PyTorchFileReader(str(filename)) + with BytesIO(zf.get_record("data.pkl")) as pkl: + mup = _LazyLoadingUnpickler(pkl, None) + return mup.load() diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py new file mode 100644 index 0000000000000..d97dc93b03fb5 --- /dev/null +++ b/tests/tests_fabric/utilities/test_load.py @@ -0,0 +1,39 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn as nn + +from fabric.utilities.load import _lazy_load, _LazyLoad + + +def test_lazy_load_module(tmp_path): + model0 = nn.Linear(2, 2) + torch.save(model0.state_dict(), tmp_path / "model.pt") + + model1 = nn.Linear(2, 2) + with _LazyLoad(tmp_path / "model.pt") as checkpoint: + model1.load_state_dict(checkpoint) + + assert torch.equal(model0.weight, model1.weight) + assert torch.equal(model0.bias, model1.bias) + + +def test_lazy_load_tensor(tmp_path): + data = torch.rand(2, 2) + torch.save({"data": data}, tmp_path / "data.pt") + + with _LazyLoad(tmp_path / "data.pt") as checkpoint: + loaded_data = checkpoint["data"] + + assert torch.equal(loaded_data, data) From 47ffc6ed7ce9c58702960a203bb3b8d94fbbd6ca Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Jul 2023 16:58:58 +0200 Subject: [PATCH 02/27] simplify --- src/lightning/fabric/utilities/load.py | 15 ++++++------ tests/tests_fabric/utilities/test_load.py | 30 ++++++++++++++++++++--- 2 files changed, 34 insertions(+), 11 deletions(-) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index ff56e6f28c6b9..dafd848404522 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -10,7 +10,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from contextlib import contextmanager from functools import partial import pickle import warnings @@ -76,7 +76,7 @@ def _load_tensor(self) -> Tensor: dtype = self.metatensor.dtype uts = ( - self.archiveinfo.zipfile_context.zf.get_storage_from_record( + self.archiveinfo.zipfile.get_storage_from_record( f"data/{fn}", size * torch._utils._element_size(dtype), UntypedStorage ) ._typed_storage() @@ -132,9 +132,9 @@ def __repr__(self) -> str: class _LazyLoadingUnpickler(pickle.Unpickler): - def __init__(self, file: IO, zipfile_context) -> None: + def __init__(self, file: IO, zipfile) -> None: super().__init__(file) - self.zipfile_context = zipfile_context # TODO: is this needed? + self.zipfile = zipfile # TODO: is this needed? 
def find_class(self, module: str, name: str) -> Any: res = super().find_class(module, name) # TODO: move to bottom @@ -160,7 +160,7 @@ class _LazyLoad: def __init__(self, filename: _PATH) -> None: self.zf = torch._C.PyTorchFileReader(str(filename)) with BytesIO(self.zf.get_record("data.pkl")) as pkl: - mup = _LazyLoadingUnpickler(pkl, self) + mup = _LazyLoadingUnpickler(pkl, self.zf) self.sd = mup.load() def __enter__(self): @@ -171,8 +171,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.zf = None +@contextmanager def _lazy_load(filename: _PATH) -> Any: zf = torch._C.PyTorchFileReader(str(filename)) with BytesIO(zf.get_record("data.pkl")) as pkl: - mup = _LazyLoadingUnpickler(pkl, None) - return mup.load() + mup = _LazyLoadingUnpickler(pkl, zf) + yield mup.load() diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py index d97dc93b03fb5..d754822967c55 100644 --- a/tests/tests_fabric/utilities/test_load.py +++ b/tests/tests_fabric/utilities/test_load.py @@ -14,7 +14,7 @@ import torch import torch.nn as nn -from fabric.utilities.load import _lazy_load, _LazyLoad +from fabric.utilities.load import _lazy_load, _NotYetLoadedTensor def test_lazy_load_module(tmp_path): @@ -22,9 +22,11 @@ def test_lazy_load_module(tmp_path): torch.save(model0.state_dict(), tmp_path / "model.pt") model1 = nn.Linear(2, 2) - with _LazyLoad(tmp_path / "model.pt") as checkpoint: + with _lazy_load(tmp_path / "model.pt") as checkpoint: model1.load_state_dict(checkpoint) + assert isinstance(checkpoint["weight"], _NotYetLoadedTensor) + assert type(model0.weight.data) == torch.Tensor assert torch.equal(model0.weight, model1.weight) assert torch.equal(model0.bias, model1.bias) @@ -33,7 +35,27 @@ def test_lazy_load_tensor(tmp_path): data = torch.rand(2, 2) torch.save({"data": data}, tmp_path / "data.pt") - with _LazyLoad(tmp_path / "data.pt") as checkpoint: + with _lazy_load(tmp_path / "data.pt") as checkpoint: loaded_data = checkpoint["data"] + assert torch.equal(loaded_data, data) + assert isinstance(checkpoint["data"], _NotYetLoadedTensor) - assert torch.equal(loaded_data, data) + +def test_lazy_load_mixed_state(tmp_path): + model0 = nn.Linear(2, 2) + optim0 = torch.optim.Adam(model0.parameters()) + checkpoint = { + "int": 1, + "dict": {"a": 1, "b": 2}, + "list": [1, 2, 3], + "pickled_model": model0, + "model": model0.state_dict(), + "optimizer": optim0.state_dict() + } + torch.save(checkpoint, tmp_path / "checkpoint.pt") + + model1 = nn.Linear(2, 2) + optim1 = torch.optim.Adam(model0.parameters()) + with _lazy_load(tmp_path / "checkpoint.pt") as loaded_checkpoint: + model1.load_state_dict(loaded_checkpoint["model"]) + optim1.load_state_dict(loaded_checkpoint["optimizer"]) From 622bb2e9a667843bc8b0e8db998b6460b8a2b4f8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Jul 2023 17:40:19 +0200 Subject: [PATCH 03/27] clean up --- src/lightning/fabric/utilities/load.py | 48 +++++++---------------- tests/tests_fabric/utilities/test_load.py | 23 ++++++++--- 2 files changed, 31 insertions(+), 40 deletions(-) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index dafd848404522..6b2383672169c 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -76,7 +76,7 @@ def _load_tensor(self) -> Tensor: dtype = self.metatensor.dtype uts = ( - self.archiveinfo.zipfile.get_storage_from_record( + self.archiveinfo.file_reader.get_storage_from_record( f"data/{fn}", size * 
torch._utils._element_size(dtype), UntypedStorage ) ._typed_storage() @@ -95,16 +95,12 @@ def __torch_function__( args: Sequence[Any] = (), kwargs: Optional[Dict] = None, ) -> Any: - if kwargs is None: # TODO: simplify - kwargs = {} - loaded_args = [(a._load_tensor() if isinstance(a, _NotYetLoadedTensor) else a) for a in args] + kwargs = kwargs or {} + loaded_args = [(arg._load_tensor() if isinstance(arg, _NotYetLoadedTensor) else arg) for arg in args] return func(*loaded_args, **kwargs) def __getattr__(self, name: str) -> Any: - # properties - ## TODO: device, is_...?? - ## TODO: mH, mT, H, T, data, imag, real - ## name ??? + # These properties don't require materialization and can be accessed through the meta tensor directly if name in { "dtype", "grad", @@ -115,12 +111,13 @@ def __getattr__(self, name: str) -> Any: "output_nr", "requires_grad", "retains_grad", + "size", "shape", "volatile", }: return getattr(self.metatensor, name) - if name in {"size"}: - return getattr(self.metatensor, name) + + # TODO: needed for us? # materializing with contiguous is needed for quantization if name in {"contiguous"}: return getattr(self._load_tensor(), name) @@ -128,23 +125,22 @@ def __getattr__(self, name: str) -> Any: raise AttributeError(f"{type(self)} does not have {name}") def __repr__(self) -> str: - return f"_NotYetLoadedTensor({repr(self.metatensor)})" + return f"{self.__class__.__name__}({repr(self.metatensor)})" class _LazyLoadingUnpickler(pickle.Unpickler): - def __init__(self, file: IO, zipfile) -> None: + def __init__(self, file: IO, file_reader: torch.PyTorchFileReader) -> None: super().__init__(file) - self.zipfile = zipfile # TODO: is this needed? + self.file_reader = file_reader def find_class(self, module: str, name: str) -> Any: - res = super().find_class(module, name) # TODO: move to bottom if module == "torch._utils" and name == "_rebuild_tensor_v2": return partial(_NotYetLoadedTensor.rebuild_tensor_v2, archiveinfo=self) if module == "torch._tensor" and name == "_rebuild_from_type_v2": return partial(_NotYetLoadedTensor.rebuild_from_type_v2, archiveinfo=self) if module == "torch._utils" and name == "_rebuild_parameter": return partial(_NotYetLoadedTensor.rebuild_parameter, archiveinfo=self) - return res + return super().find_class(module, name) def persistent_load(self, pid) -> TypedStorage: name, cls, fn, device, size = pid @@ -155,25 +151,9 @@ def persistent_load(self, pid) -> TypedStorage: return storage -# TODO: make functional? -class _LazyLoad: - def __init__(self, filename: _PATH) -> None: - self.zf = torch._C.PyTorchFileReader(str(filename)) - with BytesIO(self.zf.get_record("data.pkl")) as pkl: - mup = _LazyLoadingUnpickler(pkl, self.zf) - self.sd = mup.load() - - def __enter__(self): - return self.sd - - def __exit__(self, exc_type, exc_val, exc_tb): - del self.zf # I don't think there is a way to force closing... 
- self.zf = None - - @contextmanager def _lazy_load(filename: _PATH) -> Any: - zf = torch._C.PyTorchFileReader(str(filename)) - with BytesIO(zf.get_record("data.pkl")) as pkl: - mup = _LazyLoadingUnpickler(pkl, zf) + file_reader = torch.PyTorchFileReader(str(filename)) + with BytesIO(file_reader.get_record("data.pkl")) as pkl: + mup = _LazyLoadingUnpickler(pkl, file_reader) yield mup.load() diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py index d754822967c55..3aca5187330db 100644 --- a/tests/tests_fabric/utilities/test_load.py +++ b/tests/tests_fabric/utilities/test_load.py @@ -31,14 +31,25 @@ def test_lazy_load_module(tmp_path): assert torch.equal(model0.bias, model1.bias) +class ATensor(torch.Tensor): + pass + + def test_lazy_load_tensor(tmp_path): - data = torch.rand(2, 2) - torch.save({"data": data}, tmp_path / "data.pt") + """Test that lazy load can handle different classes of tensors.""" + expected = { + "tensor": torch.rand(2), + "parameter": nn.Parameter(torch.rand(3)), + "subclass": torch.Tensor._make_subclass(ATensor, torch.rand(4)) + } + torch.save(expected, tmp_path / "data.pt") - with _lazy_load(tmp_path / "data.pt") as checkpoint: - loaded_data = checkpoint["data"] - assert torch.equal(loaded_data, data) - assert isinstance(checkpoint["data"], _NotYetLoadedTensor) + with _lazy_load(tmp_path / "data.pt") as loaded: + for t0, t1 in zip(expected.values(), loaded.values()): + assert isinstance(t1, _NotYetLoadedTensor) + t1_materialized = t1._load_tensor() + assert type(t0) == type(t1_materialized) + assert torch.equal(t0, t1_materialized) def test_lazy_load_mixed_state(tmp_path): From b7a5a265ae99486b3149d6074d31e304baf0797e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Jul 2023 17:40:26 +0200 Subject: [PATCH 04/27] add utility to fsdp --- src/lightning/fabric/strategies/fsdp.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 848a589b0019e..4464340a15a60 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -66,6 +66,7 @@ _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1, ) +from lightning.fabric.utilities.load import _lazy_load from lightning.fabric.utilities.init import _EmptyInit from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn from lightning.fabric.utilities.seed import reset_seed @@ -821,11 +822,13 @@ def _load_raw_module_state(path_or_ckpt: Union[Path, Dict[str, Any]], module: Mo from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - # This is inefficient, as multiple copies of the checkpoint are held in CPU memory at once. - # There is currently no other way because `summon_full_params` does not support write-back from rank 0 only. 
- state_dict = torch.load(path_or_ckpt, map_location="cpu") if not isinstance(path_or_ckpt, dict) else path_or_ckpt - with FSDP.summon_full_params(module, writeback=True, rank0_only=False): - module.load_state_dict(state_dict, strict=strict) + if isinstance(path_or_ckpt, Path): + with _lazy_load(path_or_ckpt) as state_dict: + with FSDP.summon_full_params(module, writeback=True, rank0_only=False): + module.load_state_dict(state_dict, strict=strict) + else: + with FSDP.summon_full_params(module, writeback=True, rank0_only=False): + module.load_state_dict(path_or_ckpt, strict=strict) def _no_op() -> None: From 4acdc29874a5984c73b72e7dc9394f0f229bc139 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Jul 2023 15:42:33 +0000 Subject: [PATCH 05/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/strategies/fsdp.py | 7 +++---- src/lightning/fabric/utilities/load.py | 6 +++--- tests/tests_fabric/utilities/test_load.py | 5 ++--- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 4464340a15a60..db8cae681a9e2 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -66,8 +66,8 @@ _TORCH_GREATER_EQUAL_2_0, _TORCH_GREATER_EQUAL_2_1, ) -from lightning.fabric.utilities.load import _lazy_load from lightning.fabric.utilities.init import _EmptyInit +from lightning.fabric.utilities.load import _lazy_load from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn from lightning.fabric.utilities.seed import reset_seed from lightning.fabric.utilities.types import _PATH @@ -823,9 +823,8 @@ def _load_raw_module_state(path_or_ckpt: Union[Path, Dict[str, Any]], module: Mo from torch.distributed.fsdp import FullyShardedDataParallel as FSDP if isinstance(path_or_ckpt, Path): - with _lazy_load(path_or_ckpt) as state_dict: - with FSDP.summon_full_params(module, writeback=True, rank0_only=False): - module.load_state_dict(state_dict, strict=strict) + with _lazy_load(path_or_ckpt) as state_dict, FSDP.summon_full_params(module, writeback=True, rank0_only=False): + module.load_state_dict(state_dict, strict=strict) else: with FSDP.summon_full_params(module, writeback=True, rank0_only=False): module.load_state_dict(path_or_ckpt, strict=strict) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index 6b2383672169c..1d9228aff44df 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -10,12 +10,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from contextlib import contextmanager -from functools import partial import pickle import warnings +from contextlib import contextmanager +from functools import partial from io import BytesIO -from typing import IO, Any, Callable, Sequence, Optional, Dict +from typing import Any, Callable, Dict, IO, Optional, Sequence import torch import torch.utils._device diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py index 3aca5187330db..ad6979ad5aaf1 100644 --- a/tests/tests_fabric/utilities/test_load.py +++ b/tests/tests_fabric/utilities/test_load.py @@ -13,7 +13,6 @@ # limitations under the License. 
import torch import torch.nn as nn - from fabric.utilities.load import _lazy_load, _NotYetLoadedTensor @@ -40,7 +39,7 @@ def test_lazy_load_tensor(tmp_path): expected = { "tensor": torch.rand(2), "parameter": nn.Parameter(torch.rand(3)), - "subclass": torch.Tensor._make_subclass(ATensor, torch.rand(4)) + "subclass": torch.Tensor._make_subclass(ATensor, torch.rand(4)), } torch.save(expected, tmp_path / "data.pt") @@ -61,7 +60,7 @@ def test_lazy_load_mixed_state(tmp_path): "list": [1, 2, 3], "pickled_model": model0, "model": model0.state_dict(), - "optimizer": optim0.state_dict() + "optimizer": optim0.state_dict(), } torch.save(checkpoint, tmp_path / "checkpoint.pt") From 3a4934b6f0ec858b4a3cbe323a1de20498a363e1 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 24 Jul 2023 17:51:58 +0200 Subject: [PATCH 06/27] fix --- tests/tests_fabric/utilities/test_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py index ad6979ad5aaf1..607952dfc0728 100644 --- a/tests/tests_fabric/utilities/test_load.py +++ b/tests/tests_fabric/utilities/test_load.py @@ -13,7 +13,7 @@ # limitations under the License. import torch import torch.nn as nn -from fabric.utilities.load import _lazy_load, _NotYetLoadedTensor +from lightning.fabric.utilities.load import _lazy_load, _NotYetLoadedTensor def test_lazy_load_module(tmp_path): From 5cfeebbc133c86259c4eeb1c7d3283f0f3bdb971 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 24 Jul 2023 15:53:12 +0000 Subject: [PATCH 07/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_fabric/utilities/test_load.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py index 607952dfc0728..4663b7cb5e5ca 100644 --- a/tests/tests_fabric/utilities/test_load.py +++ b/tests/tests_fabric/utilities/test_load.py @@ -13,6 +13,7 @@ # limitations under the License. 
import torch import torch.nn as nn + from lightning.fabric.utilities.load import _lazy_load, _NotYetLoadedTensor From 80e4fae0d9dc0922ab4e4a6fa1c3f3034826b707 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 25 Jul 2023 14:48:18 +0200 Subject: [PATCH 08/27] make function --- src/lightning/fabric/utilities/load.py | 11 +++++++++-- tests/tests_fabric/utilities/test_load.py | 22 +++++++++++----------- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index 1d9228aff44df..556eaee2ee137 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -151,9 +151,16 @@ def persistent_load(self, pid) -> TypedStorage: return storage -@contextmanager +# @contextmanager +# def _lazy_load(filename: _PATH) -> Any: +# file_reader = torch.PyTorchFileReader(str(filename)) +# with BytesIO(file_reader.get_record("data.pkl")) as pkl: +# mup = _LazyLoadingUnpickler(pkl, file_reader) +# yield mup.load() + + def _lazy_load(filename: _PATH) -> Any: file_reader = torch.PyTorchFileReader(str(filename)) with BytesIO(file_reader.get_record("data.pkl")) as pkl: mup = _LazyLoadingUnpickler(pkl, file_reader) - yield mup.load() + return mup.load() diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py index 4663b7cb5e5ca..443a54afd181e 100644 --- a/tests/tests_fabric/utilities/test_load.py +++ b/tests/tests_fabric/utilities/test_load.py @@ -22,8 +22,8 @@ def test_lazy_load_module(tmp_path): torch.save(model0.state_dict(), tmp_path / "model.pt") model1 = nn.Linear(2, 2) - with _lazy_load(tmp_path / "model.pt") as checkpoint: - model1.load_state_dict(checkpoint) + checkpoint = _lazy_load(tmp_path / "model.pt") + model1.load_state_dict(checkpoint) assert isinstance(checkpoint["weight"], _NotYetLoadedTensor) assert type(model0.weight.data) == torch.Tensor @@ -44,12 +44,12 @@ def test_lazy_load_tensor(tmp_path): } torch.save(expected, tmp_path / "data.pt") - with _lazy_load(tmp_path / "data.pt") as loaded: - for t0, t1 in zip(expected.values(), loaded.values()): - assert isinstance(t1, _NotYetLoadedTensor) - t1_materialized = t1._load_tensor() - assert type(t0) == type(t1_materialized) - assert torch.equal(t0, t1_materialized) + loaded = _lazy_load(tmp_path / "data.pt") + for t0, t1 in zip(expected.values(), loaded.values()): + assert isinstance(t1, _NotYetLoadedTensor) + t1_materialized = t1._load_tensor() + assert type(t0) == type(t1_materialized) + assert torch.equal(t0, t1_materialized) def test_lazy_load_mixed_state(tmp_path): @@ -67,6 +67,6 @@ def test_lazy_load_mixed_state(tmp_path): model1 = nn.Linear(2, 2) optim1 = torch.optim.Adam(model0.parameters()) - with _lazy_load(tmp_path / "checkpoint.pt") as loaded_checkpoint: - model1.load_state_dict(loaded_checkpoint["model"]) - optim1.load_state_dict(loaded_checkpoint["optimizer"]) + loaded_checkpoint = _lazy_load(tmp_path / "checkpoint.pt") + model1.load_state_dict(loaded_checkpoint["model"]) + optim1.load_state_dict(loaded_checkpoint["optimizer"]) From 37276b0e17d0a7ff76e00ec8ca7ebba63e4e4650 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 25 Jul 2023 14:55:13 +0200 Subject: [PATCH 09/27] load --- src/lightning/fabric/strategies/fsdp.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index db8cae681a9e2..1ebfa5b54ed99 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ 
b/src/lightning/fabric/strategies/fsdp.py @@ -39,6 +39,7 @@ from torch.optim import Optimizer from typing_extensions import TypeGuard +from fabric.utilities.load import _lazy_load from lightning.fabric.accelerators import Accelerator from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment, Precision from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout @@ -67,7 +68,6 @@ _TORCH_GREATER_EQUAL_2_1, ) from lightning.fabric.utilities.init import _EmptyInit -from lightning.fabric.utilities.load import _lazy_load from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_only, rank_zero_warn from lightning.fabric.utilities.seed import reset_seed from lightning.fabric.utilities.types import _PATH @@ -822,12 +822,9 @@ def _load_raw_module_state(path_or_ckpt: Union[Path, Dict[str, Any]], module: Mo from torch.distributed.fsdp import FullyShardedDataParallel as FSDP - if isinstance(path_or_ckpt, Path): - with _lazy_load(path_or_ckpt) as state_dict, FSDP.summon_full_params(module, writeback=True, rank0_only=False): - module.load_state_dict(state_dict, strict=strict) - else: - with FSDP.summon_full_params(module, writeback=True, rank0_only=False): - module.load_state_dict(path_or_ckpt, strict=strict) + state_dict = _lazy_load(path_or_ckpt) if isinstance(path_or_ckpt, Path) else path_or_ckpt + with FSDP.summon_full_params(module, writeback=True, rank0_only=False): + module.load_state_dict(state_dict, strict=strict) def _no_op() -> None: From 82397a7e663fb295797d19d14ed76e8820c1b7ce Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Jul 2023 12:56:22 +0000 Subject: [PATCH 10/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/strategies/fsdp.py | 2 +- src/lightning/fabric/utilities/load.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 1ebfa5b54ed99..f65df8da87028 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -34,12 +34,12 @@ ) import torch +from fabric.utilities.load import _lazy_load from torch import Tensor from torch.nn import Module from torch.optim import Optimizer from typing_extensions import TypeGuard -from fabric.utilities.load import _lazy_load from lightning.fabric.accelerators import Accelerator from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment, Precision from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index 556eaee2ee137..0eb88f3713170 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -12,7 +12,6 @@ # limitations under the License. 
import pickle import warnings -from contextlib import contextmanager from functools import partial from io import BytesIO from typing import Any, Callable, Dict, IO, Optional, Sequence From 87817176b93b51933a70f18c417131f01aecbc09 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 25 Jul 2023 14:58:23 +0200 Subject: [PATCH 11/27] note --- src/lightning/fabric/strategies/fsdp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 1ebfa5b54ed99..738ad24def286 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -822,6 +822,7 @@ def _load_raw_module_state(path_or_ckpt: Union[Path, Dict[str, Any]], module: Mo from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + # Use `lazy_load` instead of `torch.load` here to avoid storing a copy of the full checkpoint per rank state_dict = _lazy_load(path_or_ckpt) if isinstance(path_or_ckpt, Path) else path_or_ckpt with FSDP.summon_full_params(module, writeback=True, rank0_only=False): module.load_state_dict(state_dict, strict=strict) From 7313e0dc66ca0161b8ad83ce14fcc87254859592 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 25 Jul 2023 15:03:13 +0200 Subject: [PATCH 12/27] materialize --- src/lightning/fabric/strategies/fsdp.py | 7 +++++-- src/lightning/fabric/utilities/load.py | 6 ++++++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 738ad24def286..23777766f1237 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -39,7 +39,7 @@ from torch.optim import Optimizer from typing_extensions import TypeGuard -from fabric.utilities.load import _lazy_load +from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors from lightning.fabric.accelerators import Accelerator from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment, Precision from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout @@ -574,7 +574,7 @@ def load_checkpoint( return metadata if _is_full_checkpoint(path): - checkpoint = torch.load(path, map_location="cpu") + checkpoint = _lazy_load(path) _load_raw_module_state(checkpoint.pop(module_key), module=module, strict=strict) if isinstance(state, Module): @@ -602,6 +602,9 @@ def load_checkpoint( requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict) + # Materialize lazy tensors if there are any left in the checkpoint + _materialize_tensors(checkpoint) + # Load metadata (anything not a module or optimizer) for key in requested_metadata_keys: if key not in checkpoint: diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index 556eaee2ee137..8cdc880c819c8 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -164,3 +164,9 @@ def _lazy_load(filename: _PATH) -> Any: with BytesIO(file_reader.get_record("data.pkl")) as pkl: mup = _LazyLoadingUnpickler(pkl, file_reader) return mup.load() + + +def _materialize_tensors(checkpoint: Dict[str, Any]) -> None: + for k, v in checkpoint.items(): + if isinstance(v, _NotYetLoadedTensor): + checkpoint[k] = v._load_tensor() From eb409b1c44783d3e8a9cba5954b11d4804f972f4 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 25 Jul 2023 15:23:43 +0200 Subject: [PATCH 
13/27] update --- src/lightning/fabric/strategies/fsdp.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index f62e3ca6163f4..a10944e809c7d 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -580,6 +580,10 @@ def load_checkpoint( if isinstance(state, Module): return {} + # Materialize lazy tensors if there are any left in the checkpoint + # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues + _materialize_tensors(checkpoint) + # Load optimizer states for optim_key, optim in optimizers.items(): # rank0_only should be false because we need to load the optimizer state on all ranks @@ -602,9 +606,6 @@ def load_checkpoint( requested_metadata_keys = state.keys() - modules.keys() - optimizers.keys() _validate_keys_for_strict_loading(requested_metadata_keys, checkpoint.keys(), strict=strict) - # Materialize lazy tensors if there are any left in the checkpoint - _materialize_tensors(checkpoint) - # Load metadata (anything not a module or optimizer) for key in requested_metadata_keys: if key not in checkpoint: From 10a59d934ec0c66a81158e0beda8fcec65420885 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 25 Jul 2023 15:59:32 +0200 Subject: [PATCH 14/27] recursive materialize --- src/lightning/fabric/utilities/load.py | 10 +++++---- tests/tests_fabric/utilities/test_load.py | 27 +++++++++++++++++++++-- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index e9a4f504bc9b7..e880eddacca15 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -18,6 +18,7 @@ import torch import torch.utils._device +from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor from torch.storage import TypedStorage, UntypedStorage @@ -165,7 +166,8 @@ def _lazy_load(filename: _PATH) -> Any: return mup.load() -def _materialize_tensors(checkpoint: Dict[str, Any]) -> None: - for k, v in checkpoint.items(): - if isinstance(v, _NotYetLoadedTensor): - checkpoint[k] = v._load_tensor() +def _materialize_tensors(collection: Any) -> Any: + def _load_tensor(t: _NotYetLoadedTensor) -> Tensor: + return t._load_tensor() + + return apply_to_collection(collection, dtype=_NotYetLoadedTensor, function=_load_tensor) diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py index 443a54afd181e..6cb278df4cea2 100644 --- a/tests/tests_fabric/utilities/test_load.py +++ b/tests/tests_fabric/utilities/test_load.py @@ -14,7 +14,7 @@ import torch import torch.nn as nn -from lightning.fabric.utilities.load import _lazy_load, _NotYetLoadedTensor +from lightning.fabric.utilities.load import _lazy_load, _NotYetLoadedTensor, _materialize_tensors def test_lazy_load_module(tmp_path): @@ -47,7 +47,7 @@ def test_lazy_load_tensor(tmp_path): loaded = _lazy_load(tmp_path / "data.pt") for t0, t1 in zip(expected.values(), loaded.values()): assert isinstance(t1, _NotYetLoadedTensor) - t1_materialized = t1._load_tensor() + t1_materialized = _materialize_tensors(t1) assert type(t0) == type(t1_materialized) assert torch.equal(t0, t1_materialized) @@ -70,3 +70,26 @@ def test_lazy_load_mixed_state(tmp_path): loaded_checkpoint = _lazy_load(tmp_path / "checkpoint.pt") model1.load_state_dict(loaded_checkpoint["model"]) 
optim1.load_state_dict(loaded_checkpoint["optimizer"]) + + +def test_materialize_tensors(tmp_path): + # Single tensor + tensor = torch.tensor([1, 2]) + torch.save(tensor, tmp_path / "tensor.pt") + loaded = _lazy_load(tmp_path / "tensor.pt") + materialized = _materialize_tensors(loaded) + assert torch.equal(materialized, tensor) + assert type(tensor) == type(materialized) + + # Collection of tensors + collection = { + "tensor": torch.tensor([1, 2]), + "nested": {"int": 1, "list": [torch.tensor([3.]), torch.tensor([4])]} + } + torch.save(collection, tmp_path / "collection.pt") + loaded = _lazy_load(tmp_path / "collection.pt") + materialized = _materialize_tensors(loaded) + assert torch.equal(materialized["tensor"], collection["tensor"]) + assert torch.equal(materialized["nested"]["list"][0], collection["nested"]["list"][0]) + assert torch.equal(materialized["nested"]["list"][1], collection["nested"]["list"][1]) + assert materialized["nested"]["int"] == 1 From 464f13135ab3bf06627f72458e63f30729c4d0f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Jul 2023 14:00:55 +0000 Subject: [PATCH 15/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_fabric/utilities/test_load.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py index 6cb278df4cea2..ee8d2809faa11 100644 --- a/tests/tests_fabric/utilities/test_load.py +++ b/tests/tests_fabric/utilities/test_load.py @@ -14,7 +14,7 @@ import torch import torch.nn as nn -from lightning.fabric.utilities.load import _lazy_load, _NotYetLoadedTensor, _materialize_tensors +from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors, _NotYetLoadedTensor def test_lazy_load_module(tmp_path): @@ -84,7 +84,7 @@ def test_materialize_tensors(tmp_path): # Collection of tensors collection = { "tensor": torch.tensor([1, 2]), - "nested": {"int": 1, "list": [torch.tensor([3.]), torch.tensor([4])]} + "nested": {"int": 1, "list": [torch.tensor([3.0]), torch.tensor([4])]}, } torch.save(collection, tmp_path / "collection.pt") loaded = _lazy_load(tmp_path / "collection.pt") From f19f09ff526abb2b61e927fb0ed1a5e53da86203 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 25 Jul 2023 10:02:06 -0400 Subject: [PATCH 16/27] fix --- src/lightning/fabric/strategies/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index a10944e809c7d..d02966eb73da8 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -582,7 +582,7 @@ def load_checkpoint( # Materialize lazy tensors if there are any left in the checkpoint # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues - _materialize_tensors(checkpoint) + checkpoint = _materialize_tensors(checkpoint) # Load optimizer states for optim_key, optim in optimizers.items(): From df3435962ba0432a21d6ea97ca1333357f8ac3cf Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 25 Jul 2023 16:25:26 +0200 Subject: [PATCH 17/27] clean up --- src/lightning/fabric/utilities/load.py | 74 +++++++++++++++++--------- 1 file changed, 48 insertions(+), 26 deletions(-) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index 
e880eddacca15..66acc127f4cd9 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -14,12 +14,14 @@ import warnings from functools import partial from io import BytesIO -from typing import Any, Callable, Dict, IO, Optional, Sequence +from typing import Any, Callable, Dict, IO, Optional, Sequence, OrderedDict, Union import torch import torch.utils._device from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor +from torch._C import _TensorMeta +from torch.nn import Parameter from torch.storage import TypedStorage, UntypedStorage from lightning.fabric.utilities.types import _PATH @@ -27,43 +29,73 @@ # Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann class _NotYetLoadedTensor: - def __init__(self, metatensor, archiveinfo, storageinfo, rebuild_args): + def __init__( + self, + metatensor: Tensor, + archiveinfo: "_LazyLoadingUnpickler", + storageinfo: tuple, + rebuild_args: tuple, + ) -> None: self.metatensor = metatensor self.archiveinfo = archiveinfo self.storageinfo = storageinfo self.rebuild_args = rebuild_args @classmethod - def rebuild_from_type_v2(cls, func, new_type, args, state, *, archiveinfo=None): + def rebuild_from_type_v2( + cls, + func: Callable, + new_type: _TensorMeta, + args: tuple, + state: dict, + *, + archiveinfo: Optional[pickle.Unpickler] = None, + ) -> Any: ret = func(*args) if isinstance(ret, _NotYetLoadedTensor): old_lt = ret._load_tensor - def _load_tensor(): + def _load_tensor() -> Any: t = old_lt() return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state) - ret._load_tensor = _load_tensor + ret._load_tensor = _load_tensor # type: ignore[assignment] return ret return torch._tensor._rebuild_from_type_v2(func, new_type, args, state) @classmethod - def rebuild_parameter(cls, data, requires_grad, backward_hooks, *, archiveinfo=None): + def rebuild_parameter( + cls, + data: Any, + requires_grad: bool, + backward_hooks: OrderedDict, + *, + archiveinfo: Optional[pickle.Unpickler] = None, + ) -> Union[Tensor, "_NotYetLoadedTensor"]: if isinstance(data, _NotYetLoadedTensor): old_lt = data._load_tensor - def _load_tensor(): + def _load_tensor() -> Parameter: t = old_lt() return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks) - data._load_tensor = _load_tensor + data._load_tensor = _load_tensor # type: ignore[assignment] return data return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks) @classmethod def rebuild_tensor_v2( - cls, storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata=None, *, archiveinfo=None - ): + cls, + storage: TypedStorage, + storage_offset: int, + size: tuple, + stride: tuple, + requires_grad: bool, + backward_hooks: OrderedDict, + metadata: Optional[Any] = None, + *, + archiveinfo: "_LazyLoadingUnpickler", + ) -> "_NotYetLoadedTensor": rebuild_args = (storage_offset, size, stride, requires_grad, backward_hooks, metadata) metatensor = torch._utils._rebuild_tensor_v2( storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata @@ -75,16 +107,14 @@ def _load_tensor(self) -> Tensor: name, storage_cls, fn, device, size = self.storageinfo dtype = self.metatensor.dtype - uts = ( - self.archiveinfo.file_reader.get_storage_from_record( - f"data/{fn}", size * torch._utils._element_size(dtype), UntypedStorage - ) - ._typed_storage() - ._untyped_storage + storage = self.archiveinfo.file_reader.get_storage_from_record( + f"data/{fn}", size * 
torch._utils._element_size(dtype), UntypedStorage ) + uts = storage._typed_storage()._untyped_storage + with warnings.catch_warnings(): warnings.simplefilter("ignore") - storage = TypedStorage(wrap_storage=uts, dtype=self.metatensor.dtype, _internal=True) + storage = TypedStorage(wrap_storage=uts, dtype=dtype, _internal=True) return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args) @classmethod @@ -142,7 +172,7 @@ def find_class(self, module: str, name: str) -> Any: return partial(_NotYetLoadedTensor.rebuild_parameter, archiveinfo=self) return super().find_class(module, name) - def persistent_load(self, pid) -> TypedStorage: + def persistent_load(self, pid: tuple) -> TypedStorage: name, cls, fn, device, size = pid with warnings.catch_warnings(): warnings.simplefilter("ignore") # TODO: needed? @@ -151,14 +181,6 @@ def persistent_load(self, pid) -> TypedStorage: return storage -# @contextmanager -# def _lazy_load(filename: _PATH) -> Any: -# file_reader = torch.PyTorchFileReader(str(filename)) -# with BytesIO(file_reader.get_record("data.pkl")) as pkl: -# mup = _LazyLoadingUnpickler(pkl, file_reader) -# yield mup.load() - - def _lazy_load(filename: _PATH) -> Any: file_reader = torch.PyTorchFileReader(str(filename)) with BytesIO(file_reader.get_record("data.pkl")) as pkl: From 08692942ef7c0ff9bc2dd694e885246f4f18d0e7 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Tue, 25 Jul 2023 16:28:53 +0200 Subject: [PATCH 18/27] update --- src/lightning/fabric/utilities/load.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index 66acc127f4cd9..a2ac6c2ce2665 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -49,7 +49,7 @@ def rebuild_from_type_v2( args: tuple, state: dict, *, - archiveinfo: Optional[pickle.Unpickler] = None, + archiveinfo: Optional["_LazyLoadingUnpickler"] = None, ) -> Any: ret = func(*args) if isinstance(ret, _NotYetLoadedTensor): @@ -70,7 +70,7 @@ def rebuild_parameter( requires_grad: bool, backward_hooks: OrderedDict, *, - archiveinfo: Optional[pickle.Unpickler] = None, + archiveinfo: Optional["_LazyLoadingUnpickler"] = None, ) -> Union[Tensor, "_NotYetLoadedTensor"]: if isinstance(data, _NotYetLoadedTensor): old_lt = data._load_tensor From f9e1a845d657f8a936e5b8c7779f0fe28e7295a3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 25 Jul 2023 14:29:11 +0000 Subject: [PATCH 19/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/utilities/load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index a2ac6c2ce2665..6b9317bb2ddef 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -14,7 +14,7 @@ import warnings from functools import partial from io import BytesIO -from typing import Any, Callable, Dict, IO, Optional, Sequence, OrderedDict, Union +from typing import Any, Callable, Dict, IO, Optional, OrderedDict, Sequence, Union import torch import torch.utils._device From 19652837b652d8b8cce1be5559ed440df7798f72 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Jul 2023 18:12:46 +0200 Subject: [PATCH 20/27] remove unused import --- src/lightning/fabric/utilities/load.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index 6b9317bb2ddef..01e16028fb074 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -17,7 +17,6 @@ from typing import Any, Callable, Dict, IO, Optional, OrderedDict, Sequence, Union import torch -import torch.utils._device from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor from torch._C import _TensorMeta From 9cc4e44e2ce09742621441798aab67d28af8adaf Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Jul 2023 21:23:58 +0200 Subject: [PATCH 21/27] pytorch 2.0 only --- src/lightning/fabric/strategies/fsdp.py | 9 +++++---- src/lightning/fabric/utilities/load.py | 17 +++++++++++++---- tests/tests_fabric/utilities/test_load.py | 5 +++++ 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index d02966eb73da8..d283e42c83cfd 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -574,15 +574,16 @@ def load_checkpoint( return metadata if _is_full_checkpoint(path): - checkpoint = _lazy_load(path) + checkpoint = _lazy_load(path) if _TORCH_GREATER_EQUAL_2_0 else torch.load(path, map_location="cpu") _load_raw_module_state(checkpoint.pop(module_key), module=module, strict=strict) if isinstance(state, Module): return {} - # Materialize lazy tensors if there are any left in the checkpoint - # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues - checkpoint = _materialize_tensors(checkpoint) + if _TORCH_GREATER_EQUAL_2_0: + # Materialize lazy tensors if there are any left in the checkpoint + # The `torch.Optimizer.load_state_dict` method can't load lazy tensors because of deepcopy pickle issues + checkpoint = _materialize_tensors(checkpoint) # Load optimizer states for optim_key, optim in optimizers.items(): diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index 01e16028fb074..d8f63da0ef693 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -14,16 +14,19 @@ import warnings from functools import partial from io import BytesIO -from typing import Any, Callable, Dict, IO, Optional, OrderedDict, Sequence, Union +from typing import Any, Callable, Dict, IO, Optional, OrderedDict, Sequence, Union, TYPE_CHECKING import torch from lightning_utilities.core.apply_func import apply_to_collection from torch import Tensor from torch._C import _TensorMeta from torch.nn import Parameter -from torch.storage import TypedStorage, UntypedStorage from lightning.fabric.utilities.types import _PATH +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 + +if TYPE_CHECKING: + from torch.storage import TypedStorage, UntypedStorage # Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann @@ -85,7 +88,7 @@ def _load_tensor() -> Parameter: @classmethod def rebuild_tensor_v2( cls, - storage: TypedStorage, + storage: "TypedStorage", storage_offset: int, size: tuple, stride: tuple, @@ -103,6 +106,8 @@ def rebuild_tensor_v2( return _NotYetLoadedTensor(metatensor, archiveinfo, storageinfo, rebuild_args) def _load_tensor(self) -> Tensor: + from torch.storage import TypedStorage, UntypedStorage + name, storage_cls, fn, device, size = self.storageinfo dtype = self.metatensor.dtype @@ -171,7 +176,9 @@ def find_class(self, module: str, name: str) -> 
Any: return partial(_NotYetLoadedTensor.rebuild_parameter, archiveinfo=self) return super().find_class(module, name) - def persistent_load(self, pid: tuple) -> TypedStorage: + def persistent_load(self, pid: tuple) -> "TypedStorage": + from torch.storage import TypedStorage + name, cls, fn, device, size = pid with warnings.catch_warnings(): warnings.simplefilter("ignore") # TODO: needed? @@ -181,6 +188,8 @@ def persistent_load(self, pid: tuple) -> TypedStorage: def _lazy_load(filename: _PATH) -> Any: + if not _TORCH_GREATER_EQUAL_2_0: + raise NotImplementedError(f"Lazy-loading is only supported with PyTorch >= 2.0.") file_reader = torch.PyTorchFileReader(str(filename)) with BytesIO(file_reader.get_record("data.pkl")) as pkl: mup = _LazyLoadingUnpickler(pkl, file_reader) diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py index ee8d2809faa11..49832caa5abd5 100644 --- a/tests/tests_fabric/utilities/test_load.py +++ b/tests/tests_fabric/utilities/test_load.py @@ -15,8 +15,10 @@ import torch.nn as nn from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors, _NotYetLoadedTensor +from tests_fabric.helpers.runif import RunIf +@RunIf(min_torch="2.0.0") def test_lazy_load_module(tmp_path): model0 = nn.Linear(2, 2) torch.save(model0.state_dict(), tmp_path / "model.pt") @@ -35,6 +37,7 @@ class ATensor(torch.Tensor): pass +@RunIf(min_torch="2.0.0") def test_lazy_load_tensor(tmp_path): """Test that lazy load can handle different classes of tensors.""" expected = { @@ -52,6 +55,7 @@ def test_lazy_load_tensor(tmp_path): assert torch.equal(t0, t1_materialized) +@RunIf(min_torch="2.0.0") def test_lazy_load_mixed_state(tmp_path): model0 = nn.Linear(2, 2) optim0 = torch.optim.Adam(model0.parameters()) @@ -72,6 +76,7 @@ def test_lazy_load_mixed_state(tmp_path): optim1.load_state_dict(loaded_checkpoint["optimizer"]) +@RunIf(min_torch="2.0.0") def test_materialize_tensors(tmp_path): # Single tensor tensor = torch.tensor([1, 2]) From ae74fbbddb722d13be995010c04398a328110e84 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Jul 2023 19:25:13 +0000 Subject: [PATCH 22/27] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/fabric/utilities/load.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index d8f63da0ef693..709d93644d4b3 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -14,7 +14,7 @@ import warnings from functools import partial from io import BytesIO -from typing import Any, Callable, Dict, IO, Optional, OrderedDict, Sequence, Union, TYPE_CHECKING +from typing import Any, Callable, Dict, IO, Optional, OrderedDict, Sequence, TYPE_CHECKING, Union import torch from lightning_utilities.core.apply_func import apply_to_collection @@ -22,11 +22,11 @@ from torch._C import _TensorMeta from torch.nn import Parameter -from lightning.fabric.utilities.types import _PATH from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0 +from lightning.fabric.utilities.types import _PATH if TYPE_CHECKING: - from torch.storage import TypedStorage, UntypedStorage + from torch.storage import TypedStorage # Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann @@ -189,7 +189,7 @@ def persistent_load(self, pid: tuple) -> "TypedStorage": def 
_lazy_load(filename: _PATH) -> Any: if not _TORCH_GREATER_EQUAL_2_0: - raise NotImplementedError(f"Lazy-loading is only supported with PyTorch >= 2.0.") + raise NotImplementedError("Lazy-loading is only supported with PyTorch >= 2.0.") file_reader = torch.PyTorchFileReader(str(filename)) with BytesIO(file_reader.get_record("data.pkl")) as pkl: mup = _LazyLoadingUnpickler(pkl, file_reader) From a592445631a138c179c1d7d404f981bd66ee973b Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Jul 2023 23:03:02 +0200 Subject: [PATCH 23/27] mypy --- src/lightning/fabric/utilities/load.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index 709d93644d4b3..b37cf01154571 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -61,7 +61,7 @@ def _load_tensor() -> Any: t = old_lt() return torch._tensor._rebuild_from_type_v2(lambda: t, new_type, (), state) - ret._load_tensor = _load_tensor # type: ignore[assignment] + ret._load_tensor = _load_tensor # type: ignore[method-assign] return ret return torch._tensor._rebuild_from_type_v2(func, new_type, args, state) @@ -81,7 +81,7 @@ def _load_tensor() -> Parameter: t = old_lt() return torch._utils._rebuild_parameter(t, requires_grad, backward_hooks) - data._load_tensor = _load_tensor # type: ignore[assignment] + data._load_tensor = _load_tensor # type: ignore[method-assign] return data return torch._utils._rebuild_parameter(data, requires_grad, backward_hooks) From 7b5c501ac0c000637ad896120c202dcae9773801 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Jul 2023 23:07:26 +0200 Subject: [PATCH 24/27] better attribute error message --- src/lightning/fabric/utilities/load.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index b37cf01154571..a16aa50ed0170 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -151,12 +151,11 @@ def __getattr__(self, name: str) -> Any: }: return getattr(self.metatensor, name) - # TODO: needed for us? 
- # materializing with contiguous is needed for quantization + # Materialization with contiguous is needed for quantization (see lit-gpt) if name in {"contiguous"}: return getattr(self._load_tensor(), name) - raise AttributeError(f"{type(self)} does not have {name}") + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") def __repr__(self) -> str: return f"{self.__class__.__name__}({repr(self.metatensor)})" From b23f728f757284701b7224faf612d50c70953fcc Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Jul 2023 23:09:42 +0200 Subject: [PATCH 25/27] add comment about warnings --- src/lightning/fabric/utilities/load.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index a16aa50ed0170..05dc240039fb7 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -117,6 +117,7 @@ def _load_tensor(self) -> Tensor: uts = storage._typed_storage()._untyped_storage with warnings.catch_warnings(): + # The TypedStorage APIs have heavy deprecations in torch, suppress all these warnings for now warnings.simplefilter("ignore") storage = TypedStorage(wrap_storage=uts, dtype=dtype, _internal=True) return torch._utils._rebuild_tensor_v2(storage, *self.rebuild_args) @@ -180,7 +181,8 @@ def persistent_load(self, pid: tuple) -> "TypedStorage": name, cls, fn, device, size = pid with warnings.catch_warnings(): - warnings.simplefilter("ignore") # TODO: needed? + # The TypedStorage APIs have heavy deprecations in torch, suppress all these warnings for now + warnings.simplefilter("ignore") storage = TypedStorage(dtype=cls().dtype, device="meta") storage.archiveinfo = pid return storage From b83f4a099250c18415fa0e8bda433f8aae7493b4 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Jul 2023 23:11:34 +0200 Subject: [PATCH 26/27] Add note about original author --- src/lightning/fabric/utilities/load.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/fabric/utilities/load.py b/src/lightning/fabric/utilities/load.py index 05dc240039fb7..e2065f10bee37 100644 --- a/src/lightning/fabric/utilities/load.py +++ b/src/lightning/fabric/utilities/load.py @@ -162,6 +162,7 @@ def __repr__(self) -> str: return f"{self.__class__.__name__}({repr(self.metatensor)})" +# Modified from https://github.com/lernapparat/torchhacks by Thomas Viehmann class _LazyLoadingUnpickler(pickle.Unpickler): def __init__(self, file: IO, file_reader: torch.PyTorchFileReader) -> None: super().__init__(file) From bf78b234183c21efa7cc6a168423a4faf55cbea2 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 26 Jul 2023 23:41:55 +0200 Subject: [PATCH 27/27] avoid jsonargparse deprecation message --- tests/parity_fabric/test_parity_ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/parity_fabric/test_parity_ddp.py b/tests/parity_fabric/test_parity_ddp.py index 973972609edd0..d30d2b6233886 100644 --- a/tests/parity_fabric/test_parity_ddp.py +++ b/tests/parity_fabric/test_parity_ddp.py @@ -161,6 +161,6 @@ def run_parity_test(accelerator: str = "cpu", devices: int = 2, tolerance: float if __name__ == "__main__": - from jsonargparse.cli import CLI + from jsonargparse import CLI CLI(run_parity_test)
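
Taken together, the series lands a lazy checkpoint-loading utility in `lightning.fabric.utilities.load`: `_lazy_load` unpickles a `torch.save` archive into `_NotYetLoadedTensor` placeholders that read tensor data from disk only when first accessed, and `_materialize_tensors` forces any placeholders left in a (possibly nested) collection. Below is a minimal usage sketch of the final API, distilled from the tests in this series; the toy model, optimizer, and file name are illustrative, these are private (underscore-prefixed) utilities, and per patch 21 they require PyTorch >= 2.0:

    import torch
    import torch.nn as nn

    from lightning.fabric.utilities.load import (
        _lazy_load,
        _materialize_tensors,
        _NotYetLoadedTensor,
    )

    # Save an ordinary full checkpoint (illustrative model and path).
    model = nn.Linear(2, 2)
    optim = torch.optim.Adam(model.parameters())
    torch.save(
        {"model": model.state_dict(), "optimizer": optim.state_dict()},
        "checkpoint.pt",
    )

    # Lazy load: tensor values come back as archive-backed placeholders;
    # no tensor data has been read from disk yet.
    checkpoint = _lazy_load("checkpoint.pt")
    assert isinstance(checkpoint["model"]["weight"], _NotYetLoadedTensor)

    # Module state can be loaded directly: placeholders materialize on
    # first use through `__torch_function__`.
    model.load_state_dict(checkpoint["model"])

    # Optimizer state should be materialized explicitly first, since
    # `Optimizer.load_state_dict` can't handle lazy tensors because of
    # deepcopy pickle issues (see patch 13).
    optim.load_state_dict(_materialize_tensors(checkpoint["optimizer"]))

This is the same pattern `_load_raw_module_state` in fsdp.py ends up using: lazy-load the full checkpoint so that only one rank's worth of tensor data is resident at a time, then materialize just the parts (optimizer states, metadata) that cannot tolerate placeholders.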