Skip to content

Commit 75f83c7

Browse files
authored
Merge branch 'master' into 19223_num_workers_warning
2 parents 03c274a + 4d6acda commit 75f83c7

File tree

6 files changed

+43
-56
lines changed

6 files changed

+43
-56
lines changed

requirements/docs.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
sphinx >5.0, <6.0
22
myst-parser >=0.18.1, <3.0.0
33
nbsphinx >=0.8.5, <=0.9.2
4+
nbconvert <7.14 # temporary fix for https://github.com/jupyter/nbconvert/issues/2092
45
pandoc >=1.0, <=2.3
56
docutils >=0.16, <0.21
67
sphinxcontrib-fulltoc >=1.0, <=1.2.0

src/lightning/data/streaming/reader.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(
6868

6969
# FIXME: This should be divided by the number of nodes to provide more granular support when scaling out
7070
self._delete_chunks_when_processed = self._config.num_bytes > max_cache_size if max_cache_size else False
71+
self._has_exited = False
7172

7273
def download(self, chunk_indexes: List[int]) -> None:
7374
"""Receive the list of the chunk indices to download for the current epoch."""
@@ -111,7 +112,7 @@ def _maybe_delete_chunks(self) -> None:
111112

112113
def _can_delete_chunk(self) -> bool:
113114
if self._delete_chunks_when_processed:
114-
return self._pre_download_counter == self._max_pre_download - 1
115+
return self._pre_download_counter >= self._max_pre_download - 1
115116
return self._max_cache_size is not None and _get_folder_size(self._parent_cache_dir) >= self._max_cache_size
116117

117118
def _pre_load_chunk(self, chunk_index: int) -> None:
@@ -120,9 +121,10 @@ def _pre_load_chunk(self, chunk_index: int) -> None:
120121

121122
def run(self) -> None:
122123
while True:
123-
if self._pre_download_counter <= self._max_pre_download:
124+
if self._pre_download_counter < self._max_pre_download:
124125
chunk_index = _get_from_queue(self._to_download_queue)
125126
if chunk_index == _END_TOKEN:
127+
self._has_exited = True
126128
return
127129

128130
if chunk_index is not None:

src/lightning/fabric/utilities/rank_zero.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,12 @@
3030
)
3131
from typing_extensions import ParamSpec
3232

33-
import lightning.fabric
3433
from lightning.fabric.utilities.imports import _UTILITIES_GREATER_EQUAL_0_10
3534

3635
rank_zero_module.log = logging.getLogger(__name__)
3736

3837

39-
def _get_rank(
40-
strategy: Optional["lightning.fabric.strategies.Strategy"] = None,
41-
) -> Optional[int]:
42-
if strategy is not None:
43-
return strategy.global_rank
38+
def _get_rank() -> Optional[int]:
4439
# SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
4540
# therefore LOCAL_RANK needs to be checked first
4641
rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")

src/lightning/pytorch/callbacks/early_stopping.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from typing_extensions import override
2727

2828
import lightning.pytorch as pl
29-
from lightning.fabric.utilities.rank_zero import _get_rank
3029
from lightning.pytorch.callbacks.callback import Callback
3130
from lightning.pytorch.utilities.exceptions import MisconfigurationException
3231
from lightning.pytorch.utilities.rank_zero import rank_prefixed_message, rank_zero_warn
@@ -265,12 +264,8 @@ def _improvement_message(self, current: Tensor) -> str:
265264
return msg
266265

267266
@staticmethod
268-
def _log_info(trainer: Optional["pl.Trainer"], message: str, log_rank_zero_only: bool) -> None:
269-
rank = _get_rank(
270-
strategy=(trainer.strategy if trainer is not None else None), # type: ignore[arg-type]
271-
)
272-
if trainer is not None and trainer.world_size <= 1:
273-
rank = None
267+
def _log_info(trainer: "pl.Trainer", message: str, log_rank_zero_only: bool) -> None:
268+
rank = trainer.global_rank if trainer.world_size > 1 else None
274269
message = rank_prefixed_message(message, rank)
275270
if rank is None or not log_rank_zero_only or rank == 0:
276271
log.info(message)
Lines changed: 33 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import os
22
import shutil
3+
from time import sleep
34

45
import numpy as np
6+
from lightning.data.streaming import reader
57
from lightning.data.streaming.cache import Cache
68
from lightning.data.streaming.config import ChunkedIndex
79
from lightning.data.streaming.item_loader import PyTreeLoader
8-
from lightning.data.streaming.reader import PrepareChunksThread, _get_folder_size
10+
from lightning.data.streaming.reader import _END_TOKEN, PrepareChunksThread, _get_folder_size
911
from lightning_cloud.resolver import Dir
1012

1113

@@ -36,40 +38,11 @@ def test_reader_chunk_removal(tmpdir):
3638
shutil.rmtree(cache_dir)
3739
os.makedirs(cache_dir, exist_ok=True)
3840

39-
generated = []
4041
for i in range(25):
41-
generated.append([i, len(os.listdir(cache_dir))])
42+
assert len(os.listdir(cache_dir)) <= 3
4243
index = ChunkedIndex(i, cache._get_chunk_index_from_index(i), is_last_index=i == 24)
4344
assert cache[index] == i
4445

45-
assert generated == [
46-
[0, 0],
47-
[1, 2],
48-
[2, 2],
49-
[3, 3],
50-
[4, 3],
51-
[5, 3],
52-
[6, 3],
53-
[7, 3],
54-
[8, 3],
55-
[9, 3],
56-
[10, 3],
57-
[11, 3],
58-
[12, 3],
59-
[13, 3],
60-
[14, 3],
61-
[15, 3],
62-
[16, 3],
63-
[17, 3],
64-
[18, 3],
65-
[19, 3],
66-
[20, 3],
67-
[21, 3],
68-
[22, 3],
69-
[23, 3],
70-
[24, 3],
71-
]
72-
7346
assert len(os.listdir(cache_dir)) == 3
7447

7548

@@ -82,7 +55,9 @@ def test_get_folder_size(tmpdir):
8255
assert _get_folder_size(tmpdir) == 928 * 2
8356

8457

85-
def test_prepare_chunks_thread(tmpdir):
58+
def test_prepare_chunks_thread_eviction(tmpdir, monkeypatch):
59+
monkeypatch.setattr(reader, "_LONG_DEFAULT_TIMEOUT", 0.1)
60+
8661
cache_dir = os.path.join(tmpdir, "cache_dir")
8762
os.makedirs(cache_dir, exist_ok=True)
8863
cache = Cache(input_dir=cache_dir, chunk_size=2, max_cache_size=28020)
@@ -95,8 +70,32 @@ def test_prepare_chunks_thread(tmpdir):
9570

9671
cache._reader._try_load_config()
9772

98-
thread = PrepareChunksThread(cache._reader.config, item_loader=PyTreeLoader(), max_cache_size=1)
99-
assert thread._delete_chunks_when_processed
73+
assert len(os.listdir(cache_dir)) == 14
10074

10175
thread = PrepareChunksThread(cache._reader.config, item_loader=PyTreeLoader(), max_cache_size=10000)
10276
assert not thread._delete_chunks_when_processed
77+
78+
thread = PrepareChunksThread(cache._reader.config, item_loader=PyTreeLoader(), max_cache_size=1)
79+
assert thread._delete_chunks_when_processed
80+
81+
thread.start()
82+
83+
assert thread._pre_download_counter == 0
84+
85+
thread.download([0, 1, 2, 3, 4, 5, _END_TOKEN])
86+
87+
while thread._pre_download_counter == 0:
88+
sleep(0.01)
89+
90+
assert not thread._has_exited
91+
92+
for i in range(5):
93+
thread.delete([i])
94+
while len(os.listdir(cache_dir)) != 14 - (i + 1):
95+
sleep(0.01)
96+
97+
assert thread._pre_download_counter <= 2
98+
99+
assert len(os.listdir(cache_dir)) == 9
100+
assert thread._has_exited
101+
thread.join()

tests/tests_pytorch/callbacks/test_early_stopping.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,6 @@ def test_early_stopping_squeezes():
480480
es_mock.assert_called_once_with(torch.tensor(0))
481481

482482

483-
@pytest.mark.parametrize("trainer", [Trainer(), None])
484483
@pytest.mark.parametrize(
485484
("log_rank_zero_only", "world_size", "global_rank", "expected_log"),
486485
[
@@ -492,15 +491,11 @@ def test_early_stopping_squeezes():
492491
(True, 2, 1, None),
493492
],
494493
)
495-
def test_early_stopping_log_info(trainer, log_rank_zero_only, world_size, global_rank, expected_log):
494+
def test_early_stopping_log_info(log_rank_zero_only, world_size, global_rank, expected_log):
496495
"""Checks if log.info() gets called with expected message when used within EarlyStopping."""
497496
# build a mock trainer exposing the parametrized global_rank and world_size;
498497
# `_log_info` reads only these two attributes to decide whether to prefix the message with the rank
499-
if trainer:
500-
trainer.strategy.global_rank = global_rank
501-
trainer.strategy.world_size = world_size
502-
else:
503-
expected_log = "bar"
498+
trainer = Mock(global_rank=global_rank, world_size=world_size)
504499

505500
with mock.patch("lightning.pytorch.callbacks.early_stopping.log.info") as log_mock:
506501
EarlyStopping._log_info(trainer, "bar", log_rank_zero_only)

0 commit comments

Comments
 (0)