
Commit 89f9e9b

tchaton authored and thomas committed

Add disk usage check before downloading files (#19041)

Co-authored-by: thomas <[email protected]>
(cherry picked from commit d3df127)

1 parent 6799d59 · commit 89f9e9b

File tree

3 files changed: +32 −8 lines changed

.github/workflows/ci-tests-data.yml

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ jobs:
       - name: Testing Data
         working-directory: tests/tests_data
         # NOTE: do not include coverage report here, see: https://github.com/nedbat/coveragepy/issues/1003
+        timeout-minutes: 10
         run: |
           python -m coverage run --source lightning \
             -m pytest -v --timeout=60 --durations=60
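
Note that `timeout-minutes` is a step-level GitHub Actions setting, so this caps the whole Testing Data step at 10 minutes even when pytest's per-test `--timeout=60` never fires (for instance, if a worker process hangs outside a test body).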

src/lightning/data/streaming/data_processor.py

Lines changed: 22 additions & 8 deletions
@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import shutil
 import signal
 import tempfile
 import traceback
@@ -10,7 +11,6 @@
 from datetime import datetime
 from multiprocessing import Process, Queue
 from queue import Empty
-from shutil import copyfile, rmtree
 from time import sleep, time
 from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
 from urllib import parse
@@ -101,6 +101,16 @@ def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: in
             raise e


+def _wait_for_disk_usage_higher_than_threshold(input_dir: str, threshold_in_gb: int = 25, sleep_time: int = 3) -> None:
+    usage = shutil.disk_usage(input_dir)
+
+    while (usage.free / 1000 / 1000 / 1000) <= threshold_in_gb:
+        sleep(sleep_time)
+        usage = shutil.disk_usage(input_dir)
+
+    return
+
+
 def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None:
     """This function is used to download data from a remote directory to a cache directory to optimise reading."""
     s3 = S3Client()
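
The new helper polls shutil.disk_usage, which returns a named tuple of total, used, and free bytes for the filesystem containing the given path, and only returns once free space (converted to decimal gigabytes) exceeds the threshold. A minimal standalone sketch of the same polling pattern, with illustrative names that are not part of this commit:

    import shutil
    import time

    def wait_for_free_space(path: str = "/", threshold_in_gb: float = 25, sleep_time: float = 3) -> None:
        # shutil.disk_usage(path) reports bytes for the filesystem containing `path`;
        # dividing by 1000 ** 3 converts to decimal GB, matching the helper above.
        while shutil.disk_usage(path).free / 1000 ** 3 <= threshold_in_gb:
            time.sleep(sleep_time)

    if __name__ == "__main__":
        wait_for_free_space("/", threshold_in_gb=1)  # returns immediately unless the disk is nearly full
        print(f"free space: {shutil.disk_usage('/').free / 1000 ** 3:.1f} GB")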
@@ -123,7 +133,11 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
                 continue

             if input_dir.url is not None or input_dir.path is not None:
-                # 6. Download all the required paths to unblock the current index
+                if input_dir.url:
+                    # 6. Wait for the removers to catch up when we are downloading data.
+                    _wait_for_disk_usage_higher_than_threshold("/", 25)
+
+                # 7. Download all the required paths to unblock the current index
                 for path in paths:
                     local_path = path.replace(input_dir.path, cache_dir)

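Design note: the gate only runs for remote inputs (when input_dir.url is set), and it checks "/" rather than the cache directory itself, which appears to assume the cache lives on the root filesystem. While the downloader is blocked, the remover workers can delete files that have already been processed, freeing space until the helper's loop falls through and downloads resume.
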
@@ -141,7 +155,7 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
                         s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)

                 elif os.path.isfile(path):
-                    shutil.copyfile(path, local_path)
+                    shutil.copyfile(path, local_path)
                 else:
                     raise ValueError(f"The provided {input_dir.url} isn't supported.")


@@ -198,7 +212,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
            except Exception as e:
                print(e)
        elif os.path.isdir(output_dir.path):
-           copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+           shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
        else:
            raise ValueError(f"The provided {output_dir.path} isn't supported.")


@@ -686,7 +700,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
                local_filepath, obj.netloc, os.path.join(obj.path.lstrip("/"), os.path.basename(local_filepath))
            )
        elif os.path.isdir(output_dir.path):
-           copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+           shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))

        if num_nodes == 1 or node_rank is None:
            return
@@ -707,7 +721,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
            with open(node_index_filepath, "wb") as f:
                s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
        elif os.path.isdir(output_dir.path):
-           copyfile(remote_filepath, node_index_filepath)
+           shutil.copyfile(remote_filepath, node_index_filepath)

        merge_cache = Cache(cache_dir, chunk_bytes=1)
        merge_cache._merge_no_wait()
@@ -948,15 +962,15 @@ def _cleanup_cache(self) -> None:

        # Cleanup the cache dir folder to avoid corrupted files from previous run to be there.
        if os.path.exists(cache_dir):
-           rmtree(cache_dir, ignore_errors=True)
+           shutil.rmtree(cache_dir, ignore_errors=True)

        os.makedirs(cache_dir, exist_ok=True)

        cache_data_dir = _get_cache_data_dir()

        # Cleanup the cache data folder to avoid corrupted files from previous run to be there.
        if os.path.exists(cache_data_dir):
-           rmtree(cache_data_dir, ignore_errors=True)
+           shutil.rmtree(cache_data_dir, ignore_errors=True)

        os.makedirs(cache_data_dir, exist_ok=True)

tests/tests_data/streaming/test_data_processor.py

Lines changed: 9 additions & 0 deletions
@@ -21,6 +21,7 @@
     _map_items_to_workers_weighted,
     _remove_target,
     _upload_fn,
+    _wait_for_disk_usage_higher_than_threshold,
     _wait_for_file_to_exist,
 )
 from lightning.data.streaming.functions import LambdaDataTransformRecipe, map, optimize
@@ -159,6 +160,7 @@ def fn(*_, **__):


 @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
+@mock.patch("lightning.data.streaming.data_processor._wait_for_disk_usage_higher_than_threshold")
 def test_download_data_target(tmpdir):
     input_dir = os.path.join(tmpdir, "input_dir")
     os.makedirs(input_dir, exist_ok=True)
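
Patching _wait_for_disk_usage_higher_than_threshold here keeps test_download_data_target from stalling on CI machines that genuinely have less than 25 GB free; that test exercises the download path, not the disk gate.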
@@ -193,6 +195,13 @@ def fn(*_, **__):
     assert os.listdir(cache_dir) == ["a.txt"]


+def test_wait_for_disk_usage_higher_than_threshold():
+    disk_usage_mock = mock.Mock(side_effect=[mock.Mock(free=10e9), mock.Mock(free=10e9), mock.Mock(free=10e11)])
+    with mock.patch("lightning.data.streaming.data_processor.shutil.disk_usage", disk_usage_mock):
+        _wait_for_disk_usage_higher_than_threshold("/", 10, sleep_time=0)
+    assert disk_usage_mock.call_count == 3
+
+
 @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
 def test_wait_for_file_to_exist():
     import botocore
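
The dedicated test relies on unittest.mock's side_effect accepting an iterable: each call to the mock returns the next item, so the helper sees 10e9 free bytes (10 GB, not above the 10 GB threshold) twice and then 10e11 bytes (1000 GB), returning after exactly three disk_usage calls. A tiny self-contained illustration of that mock behavior:

    from unittest import mock

    # A Mock with an iterable side_effect returns the next value on each call.
    probe = mock.Mock(side_effect=["low", "low", "high"])
    assert [probe(), probe(), probe()] == ["low", "low", "high"]
    assert probe.call_count == 3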
