Add fault tolerance for the StreamingDataset 1/n #19049
Changes from all commits
ee3c8f0
a89c529
e524356
38c3c63
e8042da
733ad85
15400c8
c0404be
073200a
a12ca08
c7db240
cf13f37
c71b559
@@ -94,6 +94,10 @@ def __init__(
        self._is_done = False
        self._distributed_env = _DistributedEnv.detect()

    @property
    def rank(self) -> int:
        return self._reader.rank

    @property
    def filled(self) -> bool:
        """Returns whether the caching phase is done."""
@@ -102,6 +106,20 @@ def filled(self) -> bool:
        self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
        return self._is_done

    @property
    def checkpoint_dir(self) -> str:
        checkpoint_dir = os.path.join(self._cache_dir, "checkpoints")
        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir, exist_ok=True)
        return checkpoint_dir
    @property
    def checkpoint_rank_dir(self) -> str:
        checkpoint_rank_dir = os.path.join(self.checkpoint_dir, str(self.rank))
        if not os.path.exists(checkpoint_rank_dir):
            os.makedirs(checkpoint_rank_dir, exist_ok=True)
        return checkpoint_rank_dir

Reviewer: It shouldn't be necessary to duplicate the code in both of these.

Author: Sure.
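Below, a minimal sketch of the deduplication the reviewer asks for; the `_ensure_dir` helper name is an assumption, not part of this PR. Note that `os.makedirs(..., exist_ok=True)` already tolerates an existing directory, so the `os.path.exists` guard is redundant as well:

```python
import os


def _ensure_dir(path: str) -> str:
    """Create `path` if it doesn't exist yet and return it."""
    os.makedirs(path, exist_ok=True)  # exist_ok=True makes a separate existence check unnecessary
    return path


# Both properties then reduce to one-liners:
#     @property
#     def checkpoint_dir(self) -> str:
#         return _ensure_dir(os.path.join(self._cache_dir, "checkpoints"))
#
#     @property
#     def checkpoint_rank_dir(self) -> str:
#         return _ensure_dir(os.path.join(self.checkpoint_dir, str(self.rank)))
```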
    def __setitem__(self, index: int, data: Any) -> None:
        """Store an item in the writer."""
        self._writer[index] = data
@@ -51,3 +51,5 @@
    18: torch.long,
    19: torch.bool,
}

_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
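To show how this format round-trips, here is a small illustration mirroring the `checkpoint` and `_string_to_datetime` code later in this diff (the example timestamp is invented):

```python
from datetime import datetime

_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"

# Formatting: the checkpoint writer stamps each filename with the current time.
now = datetime.now().strftime(_TIME_FORMAT)  # e.g. "2023-11-28_14-05-02.123456Z"
filename = f"checkpoint-{now}.json"

# Parsing: `state_dict` sorts the checkpoint files chronologically by inverting the format.
stamp = filename.split("checkpoint-")[1].split(".json")[0]
assert datetime.strptime(stamp, _TIME_FORMAT).strftime(_TIME_FORMAT) == stamp
```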
@@ -12,20 +12,34 @@
# limitations under the License.

import hashlib
import json
import os
import shutil
import sys
import tempfile
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from time import time
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from torch.utils.data import IterableDataset

from lightning.data.streaming import Cache
-from lightning.data.streaming.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST
+from lightning.data.streaming.constants import (
+    _DEFAULT_CACHE_DIR,
+    _INDEX_FILENAME,
+    _LIGHTNING_CLOUD_LATEST,
+    _TIME_FORMAT,
+)
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.serializers import Serializer
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv
from lightning.fabric.utilities.distributed import group as _group

if _LIGHTNING_CLOUD_LATEST:
    from lightning_cloud.resolver import Dir, _resolve_dir
@@ -42,6 +56,7 @@ def __init__(
        drop_last: bool = False,
        seed: int = 42,
        serializers: Optional[Dict[str, Serializer]] = None,
        checkpoint_interval: int = 60 * 5,
    ) -> None:
        """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.

@@ -53,6 +68,7 @@ def __init__(
                all processes/workers return the same amount of data.
            seed: Random seed for shuffling.
            serializers: The serializers used to serialize and deserialize the chunks.
            checkpoint_interval: Interval in seconds at which the workers are going to store their own progress.

        """
        super().__init__()
@@ -77,13 +93,16 @@ def __init__(
        self.worker_intervals: List[List[int]] = []
        self.current_indexes: List[int] = []
        self.chunk_index = 0
        self.global_index = 0
        self.index = 0
        self.has_triggered_download = False
        self.min_items_per_replica: Optional[int] = None
        self.current_epoch = 0
        self.random_state = None
        self.shuffler: Optional[Shuffle] = None
        self.serializers = serializers
        self.checkpoint_interval = checkpoint_interval
        self._state_dict: Optional[Dict[str, Dict[str, Any]]] = None

    def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
        env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
@@ -109,11 +128,10 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
        return cache

    def _create_shuffler(self, cache: Cache) -> Shuffle:
-        return (
-            FullShuffle(cache, self.seed, self.drop_last)
-            if self.shuffle
-            else NoShuffle(cache, self.seed, self.drop_last)
-        )
+        seed = self.seed
+        if self._state_dict is not None:
+            seed = self._state_dict[str(cache.rank)]["seed"]
+        return FullShuffle(cache, seed, self.drop_last) if self.shuffle else NoShuffle(cache, seed, self.drop_last)

    def __len__(self) -> int:
        if self.shuffler is None:
@@ -126,6 +144,17 @@ def __iter__(self) -> "StreamingDataset":
        self.cache = self._create_cache(worker_env=self.worker_env)
        self.shuffler = self._create_shuffler(self.cache)

        # Handle restart
        if self._state_dict:
            self._validate_state_dict()
            state = self._state_dict[str(self.cache.rank)]

            # reload indexes
            self.chunk_index = state["chunk_index"]
            self.global_index = state["global_index"]
            self.index = state["index"]
            self.current_epoch = state["current_epoch"]

        chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
            self.distributed_env, self.current_epoch
        )
@@ -141,10 +170,26 @@ def __iter__(self) -> "StreamingDataset":
            self.worker_chunks.append(chunk_index)
            self.worker_intervals.append(chunk_interval)

-        self.current_indexes = []
-        self.chunk_index = 0
-        self.index = 0
+        # Handle restart
+        if self._state_dict:
+            state = self._state_dict[str(self.cache.rank)]
+
+            # re-generate indexes
+            interval = self.worker_intervals[self.chunk_index]
+            current_indexes = np.arange(interval[0], interval[1])
+            current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
+            self.current_indexes = current_indexes[state["index"] :]
+
+            # Bump the chunk_index
+            self.chunk_index += 1
Reviewer: Why +1 the index? We're reloading it in the line above. If the chunk wasn't complete, we would now miss the remainder?

Author:

```python
state = self._state_dict[str(self.cache.rank)]

# re-generate indexes
interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])
current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
self.current_indexes = current_indexes[state["index"] :]

# Bump the chunk_index
self.chunk_index += 1
```

No, it won't. The remainder of the interrupted chunk is already stored in `self.current_indexes` before the bump, so `chunk_index += 1` only selects which chunk is loaded next. I will clean this up in another PR.
+        else:
+            self.current_indexes = []
+            self.chunk_index = 0
+            self.global_index = 0
+            self.index = 0

        self.has_triggered_download = False
        self.last_time = time()

        return self
@@ -159,7 +204,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:

    def __next__(self) -> Any:
        # Prevent creating more batches than available on a given process
-        if self.index >= len(self):
+        if self.global_index >= len(self):
            self.current_epoch += 1
            raise StopIteration
@@ -169,14 +214,19 @@ def __next__(self) -> Any:
                self.current_epoch += 1
                raise StopIteration

            # reset index
            self.index = 0

            # Checkpoint when reaching a new chunk
            self.checkpoint(self.chunk_index)

            interval = self.worker_intervals[self.chunk_index]
            current_indexes = np.arange(interval[0], interval[1])

            assert self.shuffler is not None
-            self.current_indexes = self.shuffler(current_indexes)
-            self.chunk_index += 1
+            self.current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)

-            last_index = self.chunk_index == len(self.worker_intervals) and len(self.current_indexes) == 1
+            self.chunk_index += 1

        # Get the first index
        index = self.current_indexes.pop(0)
@@ -188,15 +238,165 @@ def __next__(self) -> Any:
                chunk_index=self.worker_chunks[self.chunk_index - 1],
                # We provide the chunk indexes only on the first call
                chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
-                last_index=last_index,
+                last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
            )
        )

        self.has_triggered_download = True
        self.global_index += 1
        self.index += 1

        # Checkpoint based on time
        if (time() - self.last_time) > self.checkpoint_interval:
            self.checkpoint(self.chunk_index - 1)

        return data
    def checkpoint(self, chunk_index: int) -> None:

Reviewer: This shouldn't be public. The user wouldn't be able to use this effectively, because they can only call it from the main process, and from the main process it never makes sense. I suggest to 1) make it private and 2) raise an error if it is called in the main process.

        # Checkpointing isn't supported for windows
        if sys.platform == "win32":
            return

        assert self.cache
        assert self.worker_env
        with tempfile.TemporaryDirectory() as tmpdir:
            tmp_checkpoint_path = os.path.join(tmpdir, "checkpoint.json")
            with open(tmp_checkpoint_path, "w") as f:
                # 1. Write the state to a tempfile
                json.dump(
                    {
                        "rank": self.cache._reader.rank,
                        "current_epoch": self.current_epoch,
                        "input_dir_path": self.input_dir.path,
                        "input_dir_url": self.input_dir.url,
                        "item_loader": self.item_loader.state_dict() if self.item_loader else None,
                        "drop_last": self.drop_last,
                        "seed": self.seed,
                        "checkpoint_interval": self.checkpoint_interval,
                        "chunk_index": chunk_index,
                        "global_index": self.global_index,
                        "index": self.index,
                        "world_size": self.distributed_env.world_size,
                        "num_workers": self.worker_env.world_size,
                        "shuffle": self.shuffle,
                    },
                    f,
                )
            # 2. Move the file to its target position to avoid a corrupted read from the main thread
            now = datetime.now().strftime(_TIME_FORMAT)
            checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, f"checkpoint-{now}.json")
            shutil.move(tmp_checkpoint_path, checkpoint_path)

Reviewer: Why is the file time stamped? This will create so many files. IMO we should overwrite the file, because only the current state matters. And since every worker only saves to its dedicated folder, it is safe.

        self.last_time = time()
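Combining the two review suggestions, a sketch of a private, overwrite-in-place variant; the `_checkpoint` name, the error message, and the abridged payload are illustrative, not the merged implementation:

```python
from torch.utils.data import get_worker_info


def _checkpoint(self, chunk_index: int) -> None:
    # Checkpointing isn't supported on Windows.
    if sys.platform == "win32":
        return

    # `get_worker_info()` returns None in the main process.
    if get_worker_info() is None:
        raise RuntimeError("`_checkpoint` is only meant to run inside DataLoader workers.")

    assert self.cache

    # Write next to the final location, then rename: `os.replace` is atomic within
    # a POSIX filesystem, so the main process never observes a half-written file.
    checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, "checkpoint.json")
    tmp_path = checkpoint_path + ".tmp"
    with open(tmp_path, "w") as f:
        # Abridged payload; the full key set is shown in the diff above.
        json.dump({"chunk_index": chunk_index, "global_index": self.global_index, "index": self.index}, f)
    os.replace(tmp_path, checkpoint_path)

    self.last_time = time()
```

With one fixed filename per worker directory, `state_dict` no longer needs to sort timestamped files; it can simply read `checkpoint.json` from each rank folder.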
    def state_dict(self) -> Dict[str, Any]:
        if self.cache is None:
            self.worker_env = _WorkerEnv.detect()
            self.cache = self._create_cache(worker_env=self.worker_env)

        state_dict: Dict[str, Any] = {}
        worker_env = _WorkerEnv.detect()
        if worker_env.world_size == 1:
            # 1. Check whether the checkpoint_dir exists
            if not os.path.exists(self.cache.checkpoint_dir):
                return state_dict
            # 2. Iterate through the workers and read the latest checkpoint

Reviewer: This step wouldn't be necessary, see the comment above.

            for worker_idx in os.listdir(self.cache.checkpoint_dir):
                checkpoints = os.listdir(os.path.join(self.cache.checkpoint_dir, str(worker_idx)))
                checkpoints = sorted(checkpoints, key=_string_to_datetime)

                # Load the latest checkpoint for this worker
                checkpoint_path = os.path.join(self.cache.checkpoint_dir, str(worker_idx), checkpoints[-1])
                with open(checkpoint_path) as f:
                    state_dict[worker_idx] = json.load(f)
            _state_dict = deepcopy(state_dict)

            if self.distributed_env.world_size > 1:
                # TODO: Move this to fabric.
                num_devices = torch.cuda.device_count() or 1
                node_ranks = []
                for index in range(self.distributed_env.world_size):
                    node_rank = index // num_devices
                    if node_rank in node_ranks:
                        continue
                    state = {}
                    obj = [_state_dict]
                    torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
                    state = obj[0]
                    state_dict.update(**state)
                    node_ranks.append(node_rank)

Reviewer (on lines +321 to +333): We should put this in a function; it would be way easier to unit test!

        else:
            raise NotImplementedError("The `state_dict` should be called on the main thread.")

Reviewer: But they aren't threads, they are proper processes. And we should raise immediately at the beginning; this would eliminate the entire if-else block, making the code much more readable.

        return state_dict
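A sketch of both suggestions combined: raise up front, and pull the collective out into a free function so it can be unit-tested in isolation (the name `_collect_worker_states` is invented):

```python
def _collect_worker_states(
    local_state: Dict[str, Any], world_size: int, num_devices: int
) -> Dict[str, Any]:
    """Broadcast the states gathered on each node and merge them into a single dict."""
    merged = dict(local_state)
    seen_node_ranks: List[int] = []
    for index in range(world_size):
        node_rank = index // num_devices
        if node_rank in seen_node_ranks:
            continue
        obj = [local_state]  # overwritten on every rank except the broadcast source
        torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
        merged.update(**obj[0])
        seen_node_ranks.append(node_rank)
    return merged


# `state_dict` would then open with a guard instead of ending in an else branch:
#     if _WorkerEnv.detect().world_size > 1:
#         raise RuntimeError("`state_dict` should only be called from the main process.")
```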
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        if state_dict:
            # the state is restored within the workers
            self._state_dict = state_dict
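End to end, the restore flow would look roughly like this; the `input_dir` value is a placeholder and the surrounding loop is illustrative only:

```python
dataset = StreamingDataset(input_dir="s3://my-bucket/optimised-data", checkpoint_interval=300)

for step, item in enumerate(dataset):
    ...  # training is interrupted at some point; workers have been checkpointing

# On the main process: aggregate the latest per-worker checkpoints.
state = dataset.state_dict()

# After the restart: hand the state back. Each worker restores its own progress
# (chunk_index, global_index, index, current_epoch) the next time `__iter__` runs.
restored = StreamingDataset(input_dir="s3://my-bucket/optimised-data", checkpoint_interval=300)
restored.load_state_dict(state)
```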
    def _validate_state_dict(self) -> None:
        assert self._state_dict
        assert self.worker_env
        assert self.cache

        env = Environment(dist_env=self.distributed_env, worker_env=self.worker_env)

        if env.num_shards != len(self._state_dict):
            raise ValueError(
                "The provided `state` size doesn't match the number of workers times the world size. "
                f"Found `{env.num_shards}` instead of `{len(self._state_dict)}`."
            )

        state: Dict[str, Any] = self._state_dict[str(self.cache.rank)]

        if state["shuffle"] != self.shuffle:
            raise ValueError(
                "The provided `shuffle` state doesn't match the current one. "
                f"Found `{self.shuffle}` instead of `{state['shuffle']}`."
            )

        if state["num_workers"] != self.worker_env.world_size:
            raise ValueError(
                "The provided `num_workers` state doesn't match the current one. "
                f"Found `{self.worker_env.world_size}` instead of `{state['num_workers']}`."
            )

        if state["input_dir_path"] != self.input_dir.path:
            raise ValueError(
                "The provided `input_dir` path state doesn't match the current one. "
                f"Found `{self.input_dir.path}` instead of `{state['input_dir_path']}`."
            )

        if state["input_dir_url"] != self.input_dir.url:
            raise ValueError(
                "The provided `input_dir` URL state doesn't match the current one. "
                f"Found `{self.input_dir.url}` instead of `{state['input_dir_url']}`."
            )

        if state["seed"] != self.seed:
            raise ValueError(
                "The provided `seed` state doesn't match the current one. "
                f"Found `{self.seed}` instead of `{state['seed']}`."
            )

        if self.item_loader and state["item_loader"] != self.item_loader.state_dict():
            raise ValueError(
                "The provided `item_loader` state doesn't match the current one. "
                f"Found `{self.item_loader.state_dict()}` instead of `{state['item_loader']}`."
            )

        if state["drop_last"] != self.drop_last:
            raise ValueError(
                "The provided `drop_last` state doesn't match the current one. "
                f"Found `{self.drop_last}` instead of `{state['drop_last']}`."
            )
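Since every check raises a distinct `ValueError`, the validation lends itself to unit tests along these lines (the `make_dataset` factory fixture is hypothetical):

```python
import pytest


def test_validate_state_dict_rejects_num_workers_mismatch(make_dataset):
    dataset = make_dataset(num_workers=2)
    state = dataset.state_dict()
    state["0"]["num_workers"] = 8  # corrupt one worker's entry

    restored = make_dataset(num_workers=2)
    restored.load_state_dict(state)
    with pytest.raises(ValueError, match="num_workers"):
        next(iter(restored))  # `_validate_state_dict` runs inside `__iter__`
```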
def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
    hash_object = hashlib.md5(input_dir.encode())

@@ -209,6 +409,10 @@ def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
    return cache_dir


def _string_to_datetime(item: str) -> datetime:
    return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)
@dataclass
class RemoteDir:
    """Holds a remote URL to a directory and a cache directory where the data will be downloaded."""