9 changes: 8 additions & 1 deletion src/lightning/data/streaming/__init__.py
@@ -16,4 +16,11 @@
from lightning.data.streaming.dataset import StreamingDataset
from lightning.data.streaming.item_loader import TokensLoader

__all__ = ["Cache", "DataProcessor", "StreamingDataset", "DataTransformRecipe", "DataChunkRecipe", "TokensLoader"]
__all__ = [
"Cache",
"DataProcessor",
"StreamingDataset",
"DataTransformRecipe",
"DataChunkRecipe",
"TokensLoader",
]
18 changes: 18 additions & 0 deletions src/lightning/data/streaming/cache.py
@@ -94,6 +94,10 @@ def __init__(
self._is_done = False
self._distributed_env = _DistributedEnv.detect()

@property
def rank(self) -> int:
return self._reader.rank

@property
def filled(self) -> bool:
"""Returns whether the caching phase is done."""
@@ -102,6 +106,20 @@ def filled(self) -> bool:
self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
return self._is_done

@property
def checkpoint_dir(self) -> str:
Contributor: It shouldn't be necessary to duplicate the code in both of these.

Contributor Author: Sure.

checkpoint_dir = os.path.join(self._cache_dir, "checkpoints")
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir, exist_ok=True)
return checkpoint_dir

@property
def checkpoint_rank_dir(self) -> str:
checkpoint_rank_dir = os.path.join(self.checkpoint_dir, str(self.rank))
if not os.path.exists(checkpoint_rank_dir):
os.makedirs(checkpoint_rank_dir, exist_ok=True)
return checkpoint_rank_dir

def __setitem__(self, index: int, data: Any) -> None:
"""Store an item in the writer."""
self._writer[index] = data
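Following up on the review note above about the duplicated directory-creation logic, a minimal refactor sketch; the _ensure_dir helper and the stripped-down class stub are hypothetical, not part of this PR:

import os


def _ensure_dir(path: str) -> str:
    """Create the directory if it does not exist and return its path."""
    os.makedirs(path, exist_ok=True)
    return path


class Cache:
    # Sketch only: `_cache_dir` and `rank` stand in for the attributes of the real class above.
    def __init__(self, cache_dir: str, rank: int) -> None:
        self._cache_dir = cache_dir
        self.rank = rank

    @property
    def checkpoint_dir(self) -> str:
        return _ensure_dir(os.path.join(self._cache_dir, "checkpoints"))

    @property
    def checkpoint_rank_dir(self) -> str:
        return _ensure_dir(os.path.join(self.checkpoint_dir, str(self.rank)))

Both properties then share a single code path, which is what the review asks for.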
2 changes: 2 additions & 0 deletions src/lightning/data/streaming/constants.py
@@ -51,3 +51,5 @@
18: torch.long,
19: torch.bool,
}

_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
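As a small illustration of how this format round-trips through the checkpoint filenames introduced below (the timestamp value is arbitrary):

from datetime import datetime

_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"

# Writing: each worker stamps its checkpoint file with the current time.
now = datetime.now().strftime(_TIME_FORMAT)
filename = f"checkpoint-{now}.json"  # e.g. checkpoint-2023-11-20_14-05-33.123456Z.json

# Reading: the same format is inverted to sort checkpoints chronologically,
# matching the _string_to_datetime helper added in dataset.py below.
parsed = datetime.strptime(filename.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)
assert parsed.strftime(_TIME_FORMAT) == now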
232 changes: 218 additions & 14 deletions src/lightning/data/streaming/dataset.py
@@ -12,20 +12,34 @@
# limitations under the License.

import hashlib
import json
import os
import shutil
import sys
import tempfile
from copy import deepcopy
from dataclasses import dataclass
from datetime import datetime
from time import time
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from torch.utils.data import IterableDataset

from lightning.data.streaming import Cache
from lightning.data.streaming.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST
from lightning.data.streaming.constants import (
_DEFAULT_CACHE_DIR,
_INDEX_FILENAME,
_LIGHTNING_CLOUD_LATEST,
_TIME_FORMAT,
)
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.serializers import Serializer
from lightning.data.streaming.shuffle import FullShuffle, NoShuffle, Shuffle
from lightning.data.utilities.env import Environment, _DistributedEnv, _WorkerEnv
from lightning.fabric.utilities.distributed import group as _group

if _LIGHTNING_CLOUD_LATEST:
from lightning_cloud.resolver import Dir, _resolve_dir
@@ -42,6 +56,7 @@ def __init__(
drop_last: bool = False,
seed: int = 42,
serializers: Optional[Dict[str, Serializer]] = None,
checkpoint_interval: int = 60 * 5,
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.

@@ -53,6 +68,7 @@
all processes/workers return the same amount of data.
seed: Random seed for shuffling.
serializers: The serializers used to serialize and deserialize the chunks.
checkpoint_interval: Interval in seconds at which the workers are going to store their own progress.

"""
super().__init__()
@@ -77,13 +93,16 @@ def __init__(
self.worker_intervals: List[List[int]] = []
self.current_indexes: List[int] = []
self.chunk_index = 0
self.global_index = 0
self.index = 0
self.has_triggered_download = False
self.min_items_per_replica: Optional[int] = None
self.current_epoch = 0
self.random_state = None
self.shuffler: Optional[Shuffle] = None
self.serializers = serializers
self.checkpoint_interval = checkpoint_interval
self._state_dict: Optional[Dict[str, Dict[str, Any]]] = None

def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
env = Environment(dist_env=self.distributed_env, worker_env=worker_env)
@@ -109,11 +128,10 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
return cache

def _create_shuffler(self, cache: Cache) -> Shuffle:
return (
FullShuffle(cache, self.seed, self.drop_last)
if self.shuffle
else NoShuffle(cache, self.seed, self.drop_last)
)
seed = self.seed
if self._state_dict is not None:
seed = self._state_dict[str(cache.rank)]["seed"]
return FullShuffle(cache, seed, self.drop_last) if self.shuffle else NoShuffle(cache, seed, self.drop_last)

def __len__(self) -> int:
if self.shuffler is None:
@@ -126,6 +144,17 @@ def __iter__(self) -> "StreamingDataset":
self.cache = self._create_cache(worker_env=self.worker_env)
self.shuffler = self._create_shuffler(self.cache)

# Handle restart
if self._state_dict:
self._validate_state_dict()
state = self._state_dict[str(self.cache.rank)]

# reload indexes
self.chunk_index = state["chunk_index"]
self.global_index = state["global_index"]
self.index = state["index"]
self.current_epoch = state["current_epoch"]

chunks_per_replica, intervals_per_replica = self.shuffler.get_chunks_and_intervals_per_ranks(
self.distributed_env, self.current_epoch
)
@@ -141,10 +170,26 @@ def __iter__(self) -> "StreamingDataset":
self.worker_chunks.append(chunk_index)
self.worker_intervals.append(chunk_interval)

self.current_indexes = []
self.chunk_index = 0
self.index = 0
# Handle restart
if self._state_dict:
state = self._state_dict[str(self.cache.rank)]

# re-generate indexes
interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])
current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
self.current_indexes = current_indexes[state["index"] :]

# Bump the chunk_index
self.chunk_index += 1
Contributor: Why +1 the index? We're reloading it in the line above. If the chunk wasn't complete, we would now miss the remainder?

Contributor Author:

  state = self._state_dict[str(self.cache.rank)]

  # re-generate indexes
  interval = self.worker_intervals[self.chunk_index]
  current_indexes = np.arange(interval[0], interval[1])
  current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
  self.current_indexes = current_indexes[state["index"] :]

  # Bump the chunk_index
  self.chunk_index += 1

No, it won't. The chunk_index is bumped only once the current_indexes are re-created.

I will clean this up in another PR.

else:
self.current_indexes = []
self.chunk_index = 0
self.global_index = 0
self.index = 0

self.has_triggered_download = False
self.last_time = time()

return self

@@ -159,7 +204,7 @@ def __getitem__(self, index: Union[ChunkedIndex, int]) -> Any:

def __next__(self) -> Any:
# Prevent to create more batch on a given process
if self.index >= len(self):
if self.global_index >= len(self):
self.current_epoch += 1
raise StopIteration

@@ -169,14 +214,19 @@ def __next__(self) -> Any:
self.current_epoch += 1
raise StopIteration

# reset index
self.index = 0

# Checkpoint when reaching a new chunk
self.checkpoint(self.chunk_index)

interval = self.worker_intervals[self.chunk_index]
current_indexes = np.arange(interval[0], interval[1])

assert self.shuffler is not None
self.current_indexes = self.shuffler(current_indexes)
self.chunk_index += 1
self.current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)

last_index = self.chunk_index == len(self.worker_intervals) and len(self.current_indexes) == 1
self.chunk_index += 1

# Get the first index
index = self.current_indexes.pop(0)
@@ -188,15 +238,165 @@ def __next__(self) -> Any:
chunk_index=self.worker_chunks[self.chunk_index - 1],
# We provide the chunks indexes only one the first
chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
last_index=last_index,
last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
)
)

self.has_triggered_download = True
self.global_index += 1
self.index += 1

# Checkpoint based on time
if (self.last_time - time()) > self.checkpoint_interval:
self.checkpoint(self.chunk_index - 1)

return data

def checkpoint(self, chunk_index: int) -> None:
Contributor: This shouldn't be public. The user wouldn't be able to use this effectively, because they can only call it from the main process, and from the main process it never makes sense.

I suggest to 1) make it private, 2) raise an error if it is called in the main process.

# Checkpointing isn't supported for windows
if sys.platform == "win32":
return

assert self.cache
assert self.worker_env

with tempfile.TemporaryDirectory() as tmpdir:
tmp_checkpoint_path = os.path.join(tmpdir, "checkpoint.json")
with open(tmp_checkpoint_path, "w") as f:
# 1. Write the state to a tempfile
json.dump(
{
"rank": self.cache._reader.rank,
"current_epoch": self.current_epoch,
"input_dir_path": self.input_dir.path,
"input_dir_url": self.input_dir.url,
"item_loader": self.item_loader.state_dict() if self.item_loader else None,
"drop_last": self.drop_last,
"seed": self.seed,
"checkpoint_interval": self.checkpoint_interval,
"chunk_index": chunk_index,
"global_index": self.global_index,
"index": self.index,
"world_size": self.distributed_env.world_size,
"num_workers": self.worker_env.world_size,
"shuffle": self.shuffle,
},
f,
)

# 3. Move the file to avoid corrupted read from the main thread.
now = datetime.now().strftime(_TIME_FORMAT)
checkpoint_path = os.path.join(self.cache.checkpoint_rank_dir, f"checkpoint-{now}.json")
Contributor: Why is the file time stamped? This will create so many files. IMO we should overwrite the file, because only the current state matters, and since every worker only saves to their dedicated folder, it is safe.


# 4. Move the file to its target position
shutil.move(tmp_checkpoint_path, checkpoint_path)

self.last_time = time()
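A minimal sketch of the guard suggested in the review above; the _checkpoint name and the get_worker_info() check are assumptions, not part of this diff, and the guard would also reject the num_workers=0 case, which may or may not be desired:

from torch.utils.data import IterableDataset, get_worker_info


class StreamingDataset(IterableDataset):  # stub for illustration; the real class is defined in this file
    def _checkpoint(self, chunk_index: int) -> None:
        # Hypothetical guard: checkpoint files are written from dataloader worker
        # processes, so refuse to run when called from the main process.
        if get_worker_info() is None:
            raise RuntimeError("`_checkpoint` must be called from a dataloader worker process.")
        ...  # the body would then proceed exactly as `checkpoint` does above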

def state_dict(self) -> Dict[str, Any]:
if self.cache is None:
self.worker_env = _WorkerEnv.detect()
self.cache = self._create_cache(worker_env=self.worker_env)

state_dict: Dict[str, Any] = {}
worker_env = _WorkerEnv.detect()
if worker_env.world_size == 1:
# 1. Check whether the checkpoint_dir exists
if not os.path.exists(self.cache.checkpoint_dir):
return state_dict

# 2. Iterate through the workers and read the latest checkpoint
Contributor: This step wouldn't be necessary, see comment above.

for worker_idx in os.listdir(self.cache.checkpoint_dir):
checkpoints = os.listdir(os.path.join(self.cache.checkpoint_dir, str(worker_idx)))
checkpoints = sorted(checkpoints, key=_string_to_datetime)

# Load the latest checkpoint for this worker
checkpoint_path = os.path.join(self.cache.checkpoint_dir, str(worker_idx), checkpoints[-1])
with open(checkpoint_path) as f:
state_dict[worker_idx] = json.load(f)

_state_dict = deepcopy(state_dict)

if self.distributed_env.world_size > 1:
# TODO: Move this to fabric.
num_devices = torch.cuda.device_count() or 1
node_ranks = []
for index in range(self.distributed_env.world_size):
node_rank = index // num_devices
if node_rank in node_ranks:
continue
state = {}
obj = [_state_dict]
torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
state = obj[0]
state_dict.update(**state)
node_ranks.append(node_rank)
Contributor (on lines +321 to +333): We should put this in a function. Way easier to unit test!

else:
raise NotImplementedError("The `state_dict` should be called on the main thread.")
Contributor: But they aren't threads, they are proper processes. And we should raise immediately at the beginning; this would eliminate the entire if-else block, making the code much more readable.

return state_dict
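One way to follow the review suggestion above and pull the distributed aggregation into a unit-testable function; the _gather_state_dicts name and standalone signature are assumptions:

from copy import deepcopy
from typing import Any, Dict, List

import torch

from lightning.fabric.utilities.distributed import group as _group


def _gather_state_dicts(local_state: Dict[str, Any], world_size: int, num_devices: int) -> Dict[str, Any]:
    """Broadcast one state dict per node and merge them, mirroring the inline loop above."""
    merged = deepcopy(local_state)
    node_ranks: List[int] = []
    for index in range(world_size):
        node_rank = index // num_devices
        if node_rank in node_ranks:
            continue
        obj = [local_state]
        torch.distributed.broadcast_object_list(obj, index, group=_group.WORLD)
        merged.update(**obj[0])
        node_ranks.append(node_rank)
    return merged

state_dict could then call _gather_state_dicts(state_dict, self.distributed_env.world_size, torch.cuda.device_count() or 1) when world_size > 1, and the distributed path could be exercised in isolation with a fake process group.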

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if state_dict:
# the state is restored within the workers
self._state_dict = state_dict

def _validate_state_dict(self) -> None:
assert self._state_dict
assert self.worker_env
assert self.cache

env = Environment(dist_env=self.distributed_env, worker_env=self.worker_env)

if env.num_shards != len(self._state_dict):
raise ValueError(
"The provided `state` size doesn't match the number workers world size. "
f"Found `{env.num_shards}` instead of `{len(self._state_dict)}`."
)

state: Dict[str, Any] = self._state_dict[str(self.cache.rank)]

if state["shuffle"] != self.shuffle:
raise ValueError(
"The provided `shuffle` state doesn't match the current one. "
f"Found `{self.shuffle}` instead of `{state['shuffle']}`."
)

if state["num_workers"] != self.worker_env.world_size:
raise ValueError(
"The provided `num_workers` state doesn't match the current one. "
f"Found `{self.worker_env.world_size}` instead of `{state['num_workers']}`."
)

if state["input_dir_path"] != self.input_dir.path:
raise ValueError(
"The provided `input_dir` path state doesn't match the current one. "
f"Found `{self.input_dir.path}` instead of `{state['input_dir_path']}`."
)

if state["input_dir_url"] != self.input_dir.url:
raise ValueError(
"The provided `input_dir` URL state doesn't match the current one. "
f"Found `{self.input_dir.url}` instead of `{state['input_dir_url']}`."
)

if state["seed"] != self.seed:
raise ValueError(
"The provided `seed` state doesn't match the current one. "
f"Found `{self.seed}` instead of `{state['seed']}`."
)

if self.item_loader and state["item_loader"] != self.item_loader.state_dict():
raise ValueError(
"The provided `item_loader` state doesn't match the current one. "
f"Found `{self.item_loader.state_dict()}` instead of `{state['item_loader']}`."
)

if state["drop_last"] != self.drop_last:
raise ValueError(
"The provided `drop_last` state doesn't match the current one. "
f"Found `{self.drop_last}` instead of `{state['drop_last']}`."
)


def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
hash_object = hashlib.md5(input_dir.encode())
@@ -209,6 +409,10 @@ def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
return cache_dir


def _string_to_datetime(item: str) -> datetime:
return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)


@dataclass
class RemoteDir:
"""Holds a remote URL to a directory and a cache directory where the data will be downloaded."""
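To tie the pieces together, a usage sketch of the resume flow this PR enables; the dataset path and constructor arguments are illustrative assumptions:

from torch.utils.data import DataLoader

from lightning.data.streaming import StreamingDataset

# Hypothetical optimised dataset location; checkpoint_interval is in seconds.
dataset = StreamingDataset("s3://my-bucket/optimised-dataset", checkpoint_interval=300)
dataloader = DataLoader(dataset, num_workers=4)

it = iter(dataloader)
for _ in range(100):  # train for a while; workers checkpoint their progress as they go
    batch = next(it)

# On the main process, aggregate the per-worker progress written to the cache directory.
state = dataset.state_dict()

# After a restart, restore the state before iterating; each worker then resumes
# from its saved chunk_index / index instead of restarting the epoch.
resumed = StreamingDataset("s3://my-bucket/optimised-dataset", checkpoint_interval=300)
resumed.load_state_dict(state)
dataloader = DataLoader(resumed, num_workers=4)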