@@ -15,7 +15,6 @@
 import json
 import os
 import shutil
-import uuid
 from dataclasses import dataclass
 from datetime import datetime
 from time import time
@@ -25,7 +24,12 @@
 from torch.utils.data import IterableDataset
 
 from lightning.data.streaming import Cache
-from lightning.data.streaming.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST
+from lightning.data.streaming.constants import (
+    _DEFAULT_CACHE_DIR,
+    _INDEX_FILENAME,
+    _LIGHTNING_CLOUD_LATEST,
+    _TIME_FORMAT,
+)
 from lightning.data.streaming.item_loader import BaseItemLoader
 from lightning.data.streaming.sampler import ChunkedIndex
 from lightning.data.streaming.serializers import Serializer
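The multi-line import makes room for the new `_TIME_FORMAT` constant. Its value is not shown in this diff, but judging by the `strftime` pattern the commit removes from `checkpoint()` and `state_dict()` below, it is presumably defined along these lines (a sketch of the assumed constants entry, not the actual module):

# lightning/data/streaming/constants.py -- assumed definition, inferred from
# the "%Y-%m-%d_%H-%M-%S.%fZ" pattern this commit removes further down.
_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"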
@@ -93,7 +97,6 @@ def __init__(
         self.random_state = None
         self.shuffler: Optional[Shuffle] = None
         self.serializers = serializers
-        self.resume_id = uuid.uuid4()
         self.checkpoint_interval = checkpoint_interval or 60 * 5
         self._state_dict: Optional[Dict] = None
 
@@ -154,13 +157,17 @@ def __iter__(self) -> "StreamingDataset":
 
         # Handle restart
         if self._state_dict:
+            self._validate_state_dict()
             state = self._state_dict[str(self.cache.rank)]
+
             self.chunk_index = state["chunk_index"]
             self.global_index = state["global_index"]
             self.index = state["index"]
+            self.current_epoch = state["current_epoch"]
+
             interval = self.worker_intervals[self.chunk_index]
             current_indexes = np.arange(interval[0], interval[1])
-            current_indexes = self.shuffler(current_indexes)
+            current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
             self.current_indexes = current_indexes[state["index"] :]
             self.has_triggered_download = False
             self.last_time = time()
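Passing `current_epoch` and `chunk_index` into the shuffler is what makes restarts deterministic: the same (seed, epoch, chunk) triple must reproduce the same permutation so that slicing with `state["index"]` resumes exactly where the interrupted run stopped. A minimal sketch of such a shuffler, assuming a NumPy-based implementation (the actual `Shuffle` class is defined elsewhere and may differ):

import numpy as np

def deterministic_shuffle(indexes: np.ndarray, seed: int, current_epoch: int, chunk_index: int) -> np.ndarray:
    # Seeding from (seed, epoch, chunk_index) replays the identical
    # permutation after a restart, which the resume slice above relies on.
    rng = np.random.RandomState([seed, current_epoch, chunk_index])
    return rng.permutation(indexes)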
@@ -200,17 +207,15 @@ def __next__(self) -> Any:
             self.index = 0
 
             # Checkpoint when reaching a new chunk
-            self.checkpoint()
+            self.checkpoint(self.chunk_index)
 
             interval = self.worker_intervals[self.chunk_index]
             current_indexes = np.arange(interval[0], interval[1])
 
             assert self.shuffler is not None
-            self.current_indexes = self.shuffler(current_indexes)
+            self.current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
             self.chunk_index += 1
 
-        last_index = self.chunk_index == len(self.worker_intervals) and len(self.current_indexes) == 1
-
         # Get the first index
         index = self.current_indexes.pop(0)
 
@@ -221,7 +226,7 @@ def __next__(self) -> Any:
                 chunk_index=self.worker_chunks[self.chunk_index - 1],
                 # We provide the chunk indexes only on the first index
                 chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
-                last_index=last_index,
+                last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
             )
         )
 
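Note the two moves in this hunk: the precomputed `last_index` variable is gone, and the flag is now evaluated inline at emission time, after the `pop(0)`, against the chunk actually being consumed (`self.chunk_index - 1`, since `chunk_index` has already been advanced past it).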
@@ -231,37 +236,38 @@ def __next__(self) -> Any:
 
         # Checkpoint based on time
         if (time() - self.last_time) > self.checkpoint_interval:
-            self.checkpoint()
+            self.checkpoint(self.chunk_index - 1)
 
         return data
 
-    def checkpoint(self) -> None:
+    def checkpoint(self, chunk_index: int) -> None:
         import tempfile
 
         with tempfile.NamedTemporaryFile(mode="w+") as tmp:
+            # 1. Write the state to a tempfile
             json.dump(
                 {
                     "rank": self.cache._reader.rank,
                     "current_epoch": self.current_epoch,
                     "input_dir_path": self.input_dir.path,
                     "input_dir_url": self.input_dir.url,
-                    "item_loader": self.item_loader.state_dict(),
+                    "item_loader": self.item_loader.state_dict() if self.item_loader else None,
                     "drop_last": self.drop_last,
                     "seed": self.seed,
                     "checkpoint_interval": self.checkpoint_interval,
-                    "chunk_index": self.chunk_index,
+                    "chunk_index": chunk_index,
                     "global_index": self.global_index,
                     "index": self.index,
                 },
                 tmp,
             )
+            # 2. Flush to make sure it is written
             tmp.flush()
 
-            now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%fZ")
-            checkpoint_path = os.path.join(self.cache.resume_folder, f"checkpoint-{now}.json")
-
-            # Should avoid corrupted read from the main thread.
+            # 3. Move the file to avoid a corrupted read from the main thread.
+            now = datetime.now().strftime(_TIME_FORMAT)
+            checkpoint_path = os.path.join(self.cache.checkpoint_dir, f"checkpoint-{now}.json")
             shutil.copyfile(tmp.name, checkpoint_path)
 
         self.last_time = time()
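`checkpoint()` uses the classic write-then-publish pattern: serialize into a temporary file, flush, then copy into place so the main thread never observes a half-written JSON file. One hedge worth noting: `shutil.copyfile` is itself not atomic, whereas a same-filesystem rename is. A sketch of the fully atomic variant, assuming the temporary file can be created inside the checkpoint directory (helper name hypothetical):

import json
import os
import tempfile

def atomic_write_json(state: dict, checkpoint_path: str) -> None:
    # Create the temp file in the destination directory so os.replace()
    # never crosses a filesystem boundary.
    fd, tmp_path = tempfile.mkstemp(dir=os.path.dirname(checkpoint_path), suffix=".tmp")
    with os.fdopen(fd, "w") as tmp:
        json.dump(state, tmp)
        tmp.flush()
        os.fsync(tmp.fileno())  # ensure the bytes reach disk before publishing
    os.replace(tmp_path, checkpoint_path)  # atomic on POSIX and Windows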
@@ -274,18 +280,17 @@ def state_dict(self) -> Dict[_DictKey, Any]:
         state_dict = {}
         worker_env = _WorkerEnv.detect()
         if worker_env.world_size == 1:
-            checkpoint_dir = os.path.join(self.cache._cache_dir, "checkpoints")
-            if not os.path.exists(checkpoint_dir):
+            # 1. Check whether the checkpoint_dir exists
+            if not os.path.exists(self.cache.checkpoint_dir):
                 return state_dict
-            for worker_idx in os.listdir(checkpoint_dir):
-                checkpoints = os.listdir(os.path.join(checkpoint_dir, str(worker_idx)))
-                checkpoints = sorted(
-                    checkpoints,
-                    key=lambda item: datetime.strptime(
-                        item.split("checkpoint-")[1].split(".json")[0], "%Y-%m-%d_%H-%M-%S.%fZ"
-                    ),
-                )
-                checkpoint_path = os.path.join(checkpoint_dir, str(worker_idx), checkpoints[-1])
+
+            # 2. Iterate through the workers and read the latest checkpoint
+            for worker_idx in os.listdir(self.cache.checkpoint_dir):
+                checkpoints = os.listdir(os.path.join(self.cache.checkpoint_dir, str(worker_idx)))
+                checkpoints = sorted(checkpoints, key=_string_to_datetime)
+
+                # Load the latest checkpoint for this worker
+                checkpoint_path = os.path.join(self.cache.checkpoint_dir, str(worker_idx), checkpoints[-1])
                 with open(checkpoint_path) as f:
                     state_dict[worker_idx] = json.load(f)
         else:
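Read together with `checkpoint()`, this implies an on-disk layout roughly like the following, one subfolder per worker with the newest timestamped file winning after the sort (worker indexes and timestamps illustrative; the writer side presumably resolves `cache.checkpoint_dir` per rank, which happens outside these hunks):

<cache.checkpoint_dir>/
    0/
        checkpoint-2023-12-01_10-00-00.000000Z.json
        checkpoint-2023-12-01_10-05-00.000000Z.json   <- checkpoints[-1] after sorting
    1/
        checkpoint-2023-12-01_10-00-01.000000Z.json

Sorting with `key=_string_to_datetime` parses the embedded timestamp rather than trusting lexicographic filename order.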
@@ -296,6 +301,46 @@ def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
         if state_dict:
             self._state_dict = state_dict
 
+    def _validate_state_dict(self) -> None:
+        env = Environment(dist_env=self.distributed_env, worker_env=self.worker_env)
+
+        if env.num_shards != len(self._state_dict):
+            raise ValueError(
+                "The provided `state` doesn't match the current number of shards (workers). "
+                f"Found {env.num_shards} instead of {len(self._state_dict)}."
+            )
+
+        state = self._state_dict[str(self.cache.rank)]
+
+        if state["input_dir_path"] != self.input_dir.path:
+            raise ValueError(
+                "The provided `input_dir` path doesn't match the current one. "
+                f"Found {self.input_dir.path} instead of {state['input_dir_path']}."
+            )
+
+        if state["input_dir_url"] != self.input_dir.url:
+            raise ValueError(
+                "The provided `input_dir` URL doesn't match the current one. "
+                f"Found {self.input_dir.url} instead of {state['input_dir_url']}."
+            )
+
+        if state["seed"] != self.seed:
+            raise ValueError(
+                f"The provided `seed` doesn't match the current one. Found {self.seed} instead of {state['seed']}."
+            )
+
+        if self.item_loader and state["item_loader"] != self.item_loader.state_dict():
+            raise ValueError(
+                "The provided `item_loader` state doesn't match the current one. "
+                f"Found {self.item_loader.state_dict()} instead of {state['item_loader']}."
+            )
+
+        if state["drop_last"] != self.drop_last:
+            raise ValueError(
+                "The provided `drop_last` state doesn't match the current one. "
+                f"Found {self.drop_last} instead of {state['drop_last']}."
+            )
+
 
 def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
     hash_object = hashlib.md5(input_dir.encode())
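Taken together, `state_dict()`, `load_state_dict()`, and `_validate_state_dict()` give a resumable loop. A hedged usage sketch (constructor arguments elided; only names that appear in this module are used):

# Gather the latest per-worker checkpoints from disk and resume from them.
dataset = StreamingDataset(input_dir=input_dir)  # other args elided
state = dataset.state_dict()

# ... persist `state` alongside the trainer checkpoint, restart the process ...

restored = StreamingDataset(input_dir=input_dir)
restored.load_state_dict(state)   # only stashes the dict
for item in restored:             # _validate_state_dict() runs on __iter__
    ...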
@@ -308,6 +353,10 @@ def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
     return cache_dir
 
 
+def _string_to_datetime(item: str) -> datetime:
+    return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)
+
+
 @dataclass
 class RemoteDir:
     """Holds a remote URL to a directory and a cache directory where the data will be downloaded."""