@@ -1,6 +1,7 @@
 import json
 import logging
 import os
+import shutil
 import signal
 import tempfile
 import traceback
@@ -10,7 +11,6 @@
 from datetime import datetime
 from multiprocessing import Process, Queue
 from queue import Empty
-from shutil import copyfile, rmtree
 from time import sleep, time
 from typing import Any, Dict, List, Optional, Tuple, TypeVar, Union
 from urllib import parse
@@ -101,6 +101,16 @@ def _wait_for_file_to_exist(s3: S3Client, obj: parse.ParseResult, sleep_time: in
                 raise e
 
 
+def _wait_for_disk_usage_higher_than_threshold(input_dir: str, threshold_in_gb: int = 25, sleep_time: int = 3) -> None:
+    usage = shutil.disk_usage(input_dir)
+
+    while (usage.free / 1000 / 1000 / 1000) <= threshold_in_gb:
+        sleep(sleep_time)
+        usage = shutil.disk_usage(input_dir)
+
+    return
+
+
 def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue_out: Queue) -> None:
     """This function is used to download data from a remote directory to a cache directory to optimise reading."""
     s3 = S3Client()
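
Aside on the new helper: `shutil.disk_usage` returns a named tuple of `total`, `used`, and `free` sizes in bytes for the filesystem containing the given path, and dividing by 1000 three times converts bytes to decimal gigabytes (GB, not GiB). A minimal, self-contained sketch of the same polling pattern (`wait_for_free_space` is a hypothetical name; the 25 GB threshold and 3-second poll mirror the defaults above):

    import shutil
    from time import sleep

    def wait_for_free_space(path: str = "/", threshold_gb: int = 25, poll_seconds: int = 3) -> None:
        # Block until the filesystem containing `path` has more than `threshold_gb` GB free.
        # disk_usage reports bytes; divide by 1000**3 for decimal GB, as the helper above does.
        while shutil.disk_usage(path).free / 1000 / 1000 / 1000 <= threshold_gb:
            sleep(poll_seconds)

    wait_for_free_space("/")  # returns immediately if enough space is already free
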
@@ -123,7 +133,11 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
             continue
 
         if input_dir.url is not None or input_dir.path is not None:
-            # 6. Download all the required paths to unblock the current index
+            if input_dir.url:
+                # 6. Wait for the removers to catch up when we are downloading data.
+                _wait_for_disk_usage_higher_than_threshold("/", 25)
+
+            # 7. Download all the required paths to unblock the current index
             for path in paths:
                 local_path = path.replace(input_dir.path, cache_dir)
 
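
At the call site, the gate applies only to remote inputs (`input_dir.url`) and polls the root filesystem, where the cache lives, so the download loop gets backpressure: the worker blocks until the remover processes bring free space on "/" back above 25 GB before fetching the next batch of paths. A hedged sketch of the resulting flow, reusing the helper added above (`download_one` and `paths` are placeholders for the download logic in the next hunk):

    def process_item(input_dir, paths, download_one) -> None:
        if input_dir.url:
            # Backpressure: wait until removers have freed enough disk before downloading more.
            _wait_for_disk_usage_higher_than_threshold("/", 25)
        for path in paths:
            download_one(path)
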
@@ -141,7 +155,7 @@ def _download_data_target(input_dir: Dir, cache_dir: str, queue_in: Queue, queue
                         s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
 
                 elif os.path.isfile(path):
-                    copyfile(path, local_path)
+                    shutil.copyfile(path, local_path)
                 else:
                     raise ValueError(f"The provided {input_dir.url} isn't supported.")
 
@@ -198,7 +212,7 @@ def _upload_fn(upload_queue: Queue, remove_queue: Queue, cache_dir: str, output_
             except Exception as e:
                 print(e)
         elif os.path.isdir(output_dir.path):
-            copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+            shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
         else:
             raise ValueError(f"The provided {output_dir.path} isn't supported.")
 
@@ -686,7 +700,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
                 local_filepath, obj.netloc, os.path.join(obj.path.lstrip("/"), os.path.basename(local_filepath))
             )
         elif os.path.isdir(output_dir.path):
-            copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
+            shutil.copyfile(local_filepath, os.path.join(output_dir.path, os.path.basename(local_filepath)))
 
         if num_nodes == 1 or node_rank is None:
             return
@@ -707,7 +721,7 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
                 with open(node_index_filepath, "wb") as f:
                     s3.client.download_fileobj(obj.netloc, obj.path.lstrip("/"), f)
             elif os.path.isdir(output_dir.path):
-                copyfile(remote_filepath, node_index_filepath)
+                shutil.copyfile(remote_filepath, node_index_filepath)
 
         merge_cache = Cache(cache_dir, chunk_bytes=1)
         merge_cache._merge_no_wait()
@@ -948,15 +962,15 @@ def _cleanup_cache(self) -> None:
 
         # Cleanup the cache dir folder to avoid corrupted files from previous run to be there.
         if os.path.exists(cache_dir):
-            rmtree(cache_dir, ignore_errors=True)
+            shutil.rmtree(cache_dir, ignore_errors=True)
 
         os.makedirs(cache_dir, exist_ok=True)
 
         cache_data_dir = _get_cache_data_dir()
 
         # Cleanup the cache data folder to avoid corrupted files from previous run to be there.
         if os.path.exists(cache_data_dir):
-            rmtree(cache_data_dir, ignore_errors=True)
+            shutil.rmtree(cache_data_dir, ignore_errors=True)
 
         os.makedirs(cache_data_dir, exist_ok=True)