
Commit b468885

jjh42, Borda, pre-commit-ci[bot] and bhimrazy committed
Make asyncio checkpointing work if validate/fit is called more than once (#20952)
* Make asyncio checkpointing work if validate/fit is called more than once.
* Apply suggestions from code review
* Add assertion to ensure executor is initialized before saving checkpoint
* update

---------

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka B <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: bhimrazy <[email protected]>

(cherry picked from commit ff64a92)
1 parent 5ac8273 commit b468885
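
For context, here is a minimal sketch of the scenario this commit addresses: reusing one `Trainer` configured with `AsyncCheckpointIO` across consecutive `validate` and `fit` calls. The model choice and `Trainer` flags below are illustrative assumptions, not code taken from this repository.

```python
# Illustrative sketch only (BoringModel and the Trainer flags are placeholder choices).
# Before this fix, the first run's teardown() shut down the ThreadPoolExecutor that
# AsyncCheckpointIO created in __init__, so a later run could fail when it tried to
# submit another checkpoint save to the already-closed executor.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.plugins.io import AsyncCheckpointIO

model = BoringModel()
trainer = Trainer(
    plugins=[AsyncCheckpointIO()],  # async wrapper; the trainer supplies the base checkpoint IO
    max_epochs=1,
    enable_progress_bar=False,
    enable_model_summary=False,
)

trainer.validate(model)  # run 1: plugin is set up, then torn down
trainer.fit(model)       # run 2: with this commit the executor is re-created lazily
```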

File tree

3 files changed: +39 −9 lines changed


src/lightning/pytorch/CHANGELOG.md

Lines changed: 4 additions & 0 deletions
```diff
@@ -28,8 +28,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `AsyncCheckpointIO` snapshots tensors to avoid race with parameter mutation ([#21079](https://github.com/Lightning-AI/pytorch-lightning/pull/21079))
 
 
+- Fixed `AsyncCheckpointIO` threadpool exception if calling fit or validate more than once ([#20952](https://github.com/Lightning-AI/pytorch-lightning/pull/20952))
+
+
 - Fixed learning rate not being correctly set after using `LearningRateFinder` callback ([#21068](https://github.com/Lightning-AI/pytorch-lightning/pull/21068))
 
+
 ---
 
 ## [2.5.3] - 2025-08-13
```

src/lightning/pytorch/plugins/io/async_plugin.py

Lines changed: 31 additions & 9 deletions
```diff
@@ -13,15 +13,17 @@
 # limitations under the License.
 
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
 from lightning_utilities.core.apply_func import apply_to_collection
 from typing_extensions import override
 
-from lightning.fabric.plugins import CheckpointIO
 from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO
 
+if TYPE_CHECKING:
+    from lightning.fabric.plugins import CheckpointIO
+
 
 class AsyncCheckpointIO(_WrappingCheckpointIO):
     """``AsyncCheckpointIO`` enables saving the checkpoints asynchronously in a thread.
@@ -33,20 +35,30 @@ class AsyncCheckpointIO(_WrappingCheckpointIO):
 
     """
 
+    _executor: Optional[ThreadPoolExecutor]
+    _error: Optional[BaseException]
+
     def __init__(self, checkpoint_io: Optional["CheckpointIO"] = None) -> None:
         super().__init__(checkpoint_io)
+        self._executor = None
+        self._error = None
+
+    # CheckpointIO doesn't have a setup method so we have to do something like.
+    def _ensure_setup(self) -> None:
+        """Ensures that the executor is setup.
 
-        self._executor = ThreadPoolExecutor(max_workers=1)
-        self._error: Optional[BaseException] = None
+        We can't do setup in __init__ because if train or validate is called more than once, the teardown method deletes
+        the executor.
+
+        """
+        if self._executor is None:
+            self._executor = ThreadPoolExecutor(max_workers=1)
 
     @override
     def save_checkpoint(self, *args: Any, **kwargs: Any) -> None:
         """Uses the ``ThreadPoolExecutor`` to save the checkpoints using the base ``checkpoint_io``."""
 
-        # snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation
-        def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
-            # detach to avoid autograd history and clone to take a point-in-time copy
-            return t.detach().clone()
+        self._ensure_setup()
 
         # rebuild args/kwargs with a cloned checkpoint (supports positional or kw form)
         if "checkpoint" in kwargs:
@@ -61,6 +73,7 @@ def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
             except BaseException as ex:
                 self._error = ex
 
+        assert self._executor is not None
         self._executor.submit(_save_checkpoint, *args, **kwargs)
 
         # if an error was raised between the previous time `save_checkpoint`` was called and now,
@@ -71,8 +84,17 @@ def _save_checkpoint(*args: Any, **kwargs: Any) -> None:
     @override
     def teardown(self) -> None:
         """This method is called to close the threads."""
-        self._executor.shutdown(wait=True)
+        if self._executor is not None:
+            self._executor.shutdown(wait=True)
+            self._executor = None
 
         # if an error was raised anytime in any of the `executor.submit` calls
         if self._error:
             raise self._error
+
+
+# snapshot the checkpoint payload on the caller thread to avoid races with parameter mutation
+def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
+    """Clones a tensor on the caller thread."""
+    # detach to avoid autograd history and clone to take a point-in-time copy
+    return t.detach().clone()
```
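
The module-level `_clone_tensor` helper is applied on the caller thread over the checkpoint collection before the payload is handed to the worker thread. Below is a hedged sketch of that snapshotting pattern; the `checkpoint` dict is made up for illustration and is not the plugin's real payload.

```python
# Sketch of the snapshot-on-caller-thread idea (the `checkpoint` dict is invented).
# apply_to_collection walks a nested collection and replaces every torch.Tensor leaf
# with a detached, point-in-time copy, so later in-place parameter updates cannot
# race with the background save.
import torch
from lightning_utilities.core.apply_func import apply_to_collection


def _clone_tensor(t: torch.Tensor) -> torch.Tensor:
    # detach to avoid autograd history and clone to take a point-in-time copy
    return t.detach().clone()


checkpoint = {"state_dict": {"weight": torch.randn(2, 2)}, "epoch": 3}
snapshot = apply_to_collection(checkpoint, dtype=torch.Tensor, function=_clone_tensor)
# non-tensor leaves such as "epoch" pass through unchanged
```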

tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -127,6 +127,10 @@ def on_fit_start(self):
         enable_progress_bar=False,
         enable_model_summary=False,
     )
+
+    # We add a validate step to test that async works when fit or validate is called multiple times.
+    trainer.validate(model)
+
     trainer.fit(model)
 
     assert checkpoint_plugin.save_checkpoint.call_count == 3
```

0 commit comments