
Commit 38c3c63

thomas authored and committed
update
1 parent e524356 commit 38c3c63

8 files changed, +101 −56 lines changed


src/lightning/data/streaming/__init__.py

Lines changed: 0 additions & 3 deletions

@@ -11,8 +11,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from lightning_cloud.resolver import Dir as InputDir
-
 from lightning.data.streaming.cache import Cache
 from lightning.data.streaming.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from lightning.data.streaming.dataset import StreamingDataset
@@ -25,5 +23,4 @@
     "DataTransformRecipe",
     "DataChunkRecipe",
     "TokensLoader",
-    "InputDir",
 ]

src/lightning/data/streaming/cache.py

Lines changed: 5 additions & 5 deletions

@@ -107,11 +107,11 @@ def filled(self) -> bool:
         return self._is_done
 
     @property
-    def resume_folder(self) -> str:
-        resume_folder = os.path.join(self._cache_dir, "checkpoints", str(self._reader.rank))
-        if not os.path.exists(resume_folder):
-            os.makedirs(resume_folder, exist_ok=True)
-        return resume_folder
+    def checkpoint_dir(self) -> str:
+        checkpoint_dir = os.path.join(self._cache_dir, "checkpoints", str(self.rank))
+        if not os.path.exists(checkpoint_dir):
+            os.makedirs(checkpoint_dir, exist_ok=True)
+        return checkpoint_dir
 
     def __setitem__(self, index: int, data: Any) -> None:
         """Store an item in the writer."""

src/lightning/data/streaming/constants.py

Lines changed: 2 additions & 0 deletions

@@ -51,3 +51,5 @@
     18: torch.long,
     19: torch.bool,
 }
+
+_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
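
_TIME_FORMAT centralizes the timestamp layout shared by the writer (naming checkpoint files) and the reader (parsing them back for sorting). A quick round-trip sketch using only the format string from this diff:

from datetime import datetime

_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"

# Writing side: the checkpoint file name embeds the current time.
now = datetime.now().strftime(_TIME_FORMAT)
filename = f"checkpoint-{now}.json"

# Reading side: the timestamp is recovered so files can be sorted chronologically.
recovered = datetime.strptime(filename.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)
print(filename, recovered)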

src/lightning/data/streaming/dataset.py

Lines changed: 77 additions & 28 deletions

@@ -15,7 +15,6 @@
 import json
 import os
 import shutil
-import uuid
 from dataclasses import dataclass
 from datetime import datetime
 from time import time
@@ -25,7 +24,12 @@
 from torch.utils.data import IterableDataset
 
 from lightning.data.streaming import Cache
-from lightning.data.streaming.constants import _DEFAULT_CACHE_DIR, _INDEX_FILENAME, _LIGHTNING_CLOUD_LATEST
+from lightning.data.streaming.constants import (
+    _DEFAULT_CACHE_DIR,
+    _INDEX_FILENAME,
+    _LIGHTNING_CLOUD_LATEST,
+    _TIME_FORMAT,
+)
 from lightning.data.streaming.item_loader import BaseItemLoader
 from lightning.data.streaming.sampler import ChunkedIndex
 from lightning.data.streaming.serializers import Serializer
@@ -93,7 +97,6 @@ def __init__(
         self.random_state = None
         self.shuffler: Optional[Shuffle] = None
         self.serializers = serializers
-        self.resume_id = uuid.uuid4()
         self.checkpoint_interval = checkpoint_interval or 60 * 5
         self._state_dict: Optional[Dict] = None
 
@@ -154,13 +157,17 @@ def __iter__(self) -> "StreamingDataset":
 
         # Handle restart
         if self._state_dict:
+            self._validate_state_dict()
             state = self._state_dict[str(self.cache.rank)]
+
             self.chunk_index = state["chunk_index"]
             self.global_index = state["global_index"]
             self.index = state["index"]
+            self.current_epoch = state["current_epoch"]
+
             interval = self.worker_intervals[self.chunk_index]
             current_indexes = np.arange(interval[0], interval[1])
-            current_indexes = self.shuffler(current_indexes)
+            current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
             self.current_indexes = current_indexes[state["index"] :]
         self.has_triggered_download = False
         self.last_time = time()
@@ -200,17 +207,15 @@ def __next__(self) -> Any:
             self.index = 0
 
             # Checkpoint when reaching a new chunk
-            self.checkpoint()
+            self.checkpoint(self.chunk_index)
 
             interval = self.worker_intervals[self.chunk_index]
             current_indexes = np.arange(interval[0], interval[1])
 
             assert self.shuffler is not None
-            self.current_indexes = self.shuffler(current_indexes)
+            self.current_indexes = self.shuffler(current_indexes, self.current_epoch, self.chunk_index)
             self.chunk_index += 1
 
-            last_index = self.chunk_index == len(self.worker_intervals) and len(self.current_indexes) == 1
-
         # Get the first index
         index = self.current_indexes.pop(0)
 
@@ -221,7 +226,7 @@ def __next__(self) -> Any:
                 chunk_index=self.worker_chunks[self.chunk_index - 1],
                 # We provide the chunks indexes only one the first
                 chunk_indexes=None if self.has_triggered_download else self.worker_chunks,
-                last_index=last_index,
+                last_index=(self.chunk_index - 1) == len(self.worker_intervals) and len(self.current_indexes) == 1,
             )
         )
 
@@ -231,37 +236,38 @@ def __next__(self) -> Any:
 
         # Checkpoint based on time
         if (self.last_time - time()) > self.checkpoint_interval:
-            self.checkpoint()
+            self.checkpoint(self.chunk_index - 1)
 
         return data
 
-    def checkpoint(self) -> None:
+    def checkpoint(self, chunk_index: int) -> None:
         import tempfile
 
         with tempfile.NamedTemporaryFile(mode="w+") as tmp:
+            # 1. Write the state to a tempfile
             json.dump(
                 {
                     "rank": self.cache._reader.rank,
                     "current_epoch": self.current_epoch,
                     "input_dir_path": self.input_dir.path,
                     "input_dir_url": self.input_dir.url,
-                    "item_loader": self.item_loader.state_dict(),
+                    "item_loader": self.item_loader.state_dict() if self.item_loader else None,
                     "drop_last": self.drop_last,
                     "seed": self.seed,
                     "checkpoint_interval": self.checkpoint_interval,
-                    "chunk_index": self.chunk_index,
+                    "chunk_index": chunk_index,
                     "global_index": self.global_index,
                     "index": self.index,
                 },
                 tmp,
            )
 
+            # 2. Flush to make sure it is written
             tmp.flush()
 
-            now = datetime.now().strftime("%Y-%m-%d_%H-%M-%S.%fZ")
-            checkpoint_path = os.path.join(self.cache.resume_folder, f"checkpoint-{now}.json")
-
-            # Should avoid corrupted read from the main thread.
+            # 3. Move the file to avoid corrupted read from the main thread.
+            now = datetime.now().strftime(_TIME_FORMAT)
+            checkpoint_path = os.path.join(self.cache.checkpoint_dir, f"checkpoint-{now}.json")
             shutil.copyfile(tmp.name, checkpoint_path)
 
         self.last_time = time()
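
The checkpoint method writes the JSON state to a temporary file, flushes it, and only then copies it into the checkpoint directory, so a reader in the main thread never sees a half-written file. A minimal standalone sketch of that write-then-copy pattern (the function and file names below are illustrative):

import json
import os
import shutil
import tempfile


def write_checkpoint_safely(state: dict, checkpoint_dir: str, filename: str) -> str:
    """Dump state to a temp file, flush, then copy the finished file into place."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    with tempfile.NamedTemporaryFile(mode="w+") as tmp:
        json.dump(state, tmp)
        tmp.flush()  # ensure the bytes are on disk before copying
        destination = os.path.join(checkpoint_dir, filename)
        shutil.copyfile(tmp.name, destination)
    return destination


path = write_checkpoint_safely({"chunk_index": 3, "index": 0}, tempfile.mkdtemp(), "checkpoint-demo.json")
print(path)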
@@ -274,18 +280,17 @@ def state_dict(self) -> Dict[_DictKey, Any]:
         state_dict = {}
         worker_env = _WorkerEnv.detect()
         if worker_env.world_size == 1:
-            checkpoint_dir = os.path.join(self.cache._cache_dir, "checkpoints")
-            if not os.path.exists(checkpoint_dir):
+            # 1. Check whether the checkpoint_dir exists
+            if not os.path.exists(self.cache.checkpoint_dir):
                 return state_dict
-            for worker_idx in os.listdir(checkpoint_dir):
-                checkpoints = os.listdir(os.path.join(checkpoint_dir, str(worker_idx)))
-                checkpoints = sorted(
-                    checkpoints,
-                    key=lambda item: datetime.strptime(
-                        item.split("checkpoint-")[1].split(".json")[0], "%Y-%m-%d_%H-%M-%S.%fZ"
-                    ),
-                )
-                checkpoint_path = os.path.join(checkpoint_dir, str(worker_idx), checkpoints[-1])
+
+            # 2. Iterate through the workers and read the latest checkpoint
+            for worker_idx in os.listdir(self.cache.checkpoint_dir):
+                checkpoints = os.listdir(os.path.join(self.cache.checkpoint_dir, str(worker_idx)))
+                checkpoints = sorted(checkpoints, key=_string_to_datetime)
+
+                # Load the latest checkpoint for this worker
+                checkpoint_path = os.path.join(self.cache.checkpoint_dir, str(worker_idx), checkpoints[-1])
                 with open(checkpoint_path) as f:
                     state_dict[worker_idx] = json.load(f)
         else:
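
state_dict now walks one sub-folder per worker and keeps only the newest checkpoint in each, using _string_to_datetime (added further down in this file) as the sort key. A small sketch of that selection step, assuming file names follow the checkpoint-<timestamp>.json convention from this diff:

from datetime import datetime

_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"


def _string_to_datetime(item: str) -> datetime:
    return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)


# Example listing for one worker folder; the latest file is picked after sorting.
checkpoints = [
    "checkpoint-2023-11-20_10-00-00.000001Z.json",
    "checkpoint-2023-11-20_10-05-00.000001Z.json",
    "checkpoint-2023-11-20_09-55-00.000001Z.json",
]
latest = sorted(checkpoints, key=_string_to_datetime)[-1]
print(latest)  # checkpoint-2023-11-20_10-05-00.000001Z.json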
@@ -296,6 +301,46 @@ def load_state_dict(self, state_dict: Dict[_DictKey, Any]) -> None:
         if state_dict:
             self._state_dict = state_dict
 
+    def _validate_state_dict(self) -> None:
+        env = Environment(dist_env=self.distributed_env, worker_env=self.worker_env)
+
+        if env.num_shards != len(self._state_dict):
+            raise ValueError(
+                "The provided `state` doesn't match the number workers world size. "
+                f"Found {env.num_shards} instead of {len(self._state_dict)}."
+            )
+
+        state = self._state_dict[str(self.cache.rank)]
+
+        if state["input_dir_path"] != self.input_dir.path:
+            raise ValueError(
+                "The provided `input_dir` path doesn't match the current one. "
+                f"Found {self.input_dir.path} instead of {state['input_dir_path']}."
+            )
+
+        if state["input_dir_url"] != self.input_dir.url:
+            raise ValueError(
+                "The provided `input_dir` URL doesn't match the current one. "
+                f"Found {self.input_dir.url} instead of {state['input_dir_url']}."
+            )
+
+        if state["seed"] != self.seed:
+            raise ValueError(
+                "The provided `seed` doesn't match the current one. " f"Found {self.seed} instead of {state['seed']}."
+            )
+
+        if self.item_loader and state["item_loader"] != self.item_loader.state_dict():
+            raise ValueError(
+                "The provided `item_loader` state doesn't match the current one. "
+                f"Found {self.item_loader.state_dict()} instead of {state['item_loader']}."
+            )
+
+        if state["drop_last"] != self.drop_last:
+            raise ValueError(
+                "The provided `drop_last` state doesn't match the current one. "
+                f"Found {self.drop_last} instead of {state['drop_last']}."
+            )
+
 
 def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
     hash_object = hashlib.md5(input_dir.encode())
@@ -308,6 +353,10 @@ def _try_create_cache_dir(input_dir: str, shard_rank: int = 0) -> Optional[str]:
     return cache_dir
 
 
+def _string_to_datetime(item: str) -> datetime:
+    return datetime.strptime(item.split("checkpoint-")[1].split(".json")[0], _TIME_FORMAT)
+
+
 @dataclass
 class RemoteDir:
     """Holds a remote URL to a directory and a cache directory where the data will be downloaded."""

src/lightning/data/streaming/shuffle.py

Lines changed: 5 additions & 9 deletions

@@ -28,7 +28,6 @@ def __init__(self, cache: Cache, seed: int, drop_last: bool):
         self.cache = cache
         self.seed = seed
         self.drop_last = drop_last
-        self.random_state = None
 
     @lru_cache(maxsize=10)
     def get_len(self, distributed_env: _DistributedEnv, current_epoch: int) -> int:
@@ -48,7 +47,7 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c
         pass
 
     @abstractmethod
-    def __call__(self, array: np.ndarray) -> List[int]:
+    def __call__(self, array: np.ndarray, current_epoch: int) -> List[int]:
         pass
 
 
@@ -68,7 +67,7 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c
 
         return chunks_per_ranks, intervals_per_ranks
 
-    def __call__(self, array: np.ndarray) -> List[int]:
+    def __call__(self, array: np.ndarray, current_epoch: int, chunk_index: int) -> List[int]:
         return array.tolist()
 
 
@@ -92,14 +91,12 @@ class FullShuffle(Shuffle):
 
     @lru_cache(maxsize=10)
     def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, current_epoch: int) -> Any:
-        self.random_state = np.random.RandomState(seed=self.seed + current_epoch)  # type: ignore
-
         # 1. Get the intervals
         chunk_intervals = self.cache.get_chunk_intervals()
 
         # 2. Shuffle them
         indexes = range(len(chunk_intervals))
-        shuffled_indexes = self.random_state.permutation(indexes)
+        shuffled_indexes = np.random.RandomState(seed=self.seed + current_epoch).permutation(indexes)
         shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes]
 
         # 3. Compute the items budget of each rank
@@ -147,6 +144,5 @@ def get_chunks_and_intervals_per_ranks(self, distributed_env: _DistributedEnv, c
 
         return chunks_per_ranks, intervals_per_ranks
 
-    def __call__(self, array: np.ndarray) -> List[int]:
-        assert self.random_state
-        return self.random_state.permutation(array).tolist()
+    def __call__(self, array: np.ndarray, current_epoch: int, chunk_index: int) -> List[int]:
+        return np.random.RandomState(seed=self.seed + current_epoch + chunk_index).permutation(array).tolist()
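
The shuffler no longer carries a random_state attribute between calls; every permutation is derived on the fly from seed + current_epoch (+ chunk_index), so a restarted worker reproduces exactly the permutation it was in the middle of. A small standalone illustration of that determinism:

import numpy as np


def shuffle_chunk(array: np.ndarray, seed: int, current_epoch: int, chunk_index: int) -> list:
    # Stateless: the generator is rebuilt from its inputs on every call.
    return np.random.RandomState(seed=seed + current_epoch + chunk_index).permutation(array).tolist()


array = np.arange(10)
first = shuffle_chunk(array, seed=42, current_epoch=1, chunk_index=3)
second = shuffle_chunk(array, seed=42, current_epoch=1, chunk_index=3)
assert first == second  # identical permutation across restarts
print(first)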

tests/tests_data/streaming/test_data_processor.py

Lines changed: 3 additions & 1 deletion

@@ -161,7 +161,7 @@ def fn(*_, **__):
 
 @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows")
 @mock.patch("lightning.data.streaming.data_processor._wait_for_disk_usage_higher_than_threshold")
-def test_download_data_target(tmpdir):
+def test_download_data_target(wait_for_disk_usage_higher_than_threshold_mock, tmpdir):
     input_dir = os.path.join(tmpdir, "input_dir")
     os.makedirs(input_dir, exist_ok=True)
 
@@ -194,6 +194,8 @@ def fn(*_, **__):
 
     assert os.listdir(cache_dir) == ["a.txt"]
 
+    wait_for_disk_usage_higher_than_threshold_mock.assert_called()
+
 
 def test_wait_for_disk_usage_higher_than_threshold():
     disk_usage_mock = mock.Mock(side_effect=[mock.Mock(free=10e9), mock.Mock(free=10e9), mock.Mock(free=10e11)])
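
The extra first parameter in test_download_data_target exists because mock.patch, used as a decorator, injects the created mock as the first positional argument, ahead of pytest fixtures such as tmpdir. A tiny illustration of that ordering with a stand-in target:

import os
from unittest import mock


@mock.patch("os.getcwd")  # stand-in target; the test patches _wait_for_disk_usage_higher_than_threshold
def check_patch_ordering(getcwd_mock, label="fixture-like trailing argument"):
    getcwd_mock.return_value = "/patched"
    assert os.getcwd() == "/patched"
    getcwd_mock.assert_called()  # same style of assertion the test now makes
    return label


check_patch_ordering()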

tests/tests_data/streaming/test_dataset.py

Lines changed: 5 additions & 6 deletions

@@ -161,7 +161,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
     dataset_iter = iter(dataset)
     assert len(dataset_iter) == 548
     process_1_1 = list(dataset_iter)
-    assert process_1_1[:10] == [785, 788, 782, 783, 789, 787, 786, 781, 784, 780]
+    assert process_1_1[:10] == [788, 781, 785, 780, 787, 782, 789, 784, 783, 786]
     assert len(process_1_1) == 548
 
     dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
@@ -172,7 +172,7 @@ def test_streaming_dataset_distributed_full_shuffle_odd(drop_last, tmpdir):
     dataset_2_iter = iter(dataset_2)
     assert len(dataset_2_iter) == 548 + int(not drop_last)
     process_2_1 = list(dataset_2_iter)
-    assert process_2_1[:10] == [939, 938, 252, 259, 257, 255, 258, 253, 250, 251]
+    assert process_2_1[:10] == [939, 938, 253, 259, 256, 258, 252, 255, 251, 257]
     assert len(process_2_1) == 548 + int(not drop_last)
     assert len([i for i in process_1_1 if i in process_2_1]) == 0
 
@@ -201,7 +201,7 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
     dataset_iter = iter(dataset)
     assert len(dataset_iter) == 611
     process_1_1 = list(dataset_iter)
-    assert process_1_1[:10] == [185, 184, 182, 189, 187, 181, 183, 180, 186, 188]
+    assert process_1_1[:10] == [188, 181, 185, 180, 187, 182, 189, 184, 183, 186]
     assert len(process_1_1) == 611
 
     dataset_2 = StreamingDataset(input_dir=str(tmpdir), shuffle=True, drop_last=drop_last)
@@ -212,9 +212,8 @@ def test_streaming_dataset_distributed_full_shuffle_even(drop_last, tmpdir):
     dataset_2_iter = iter(dataset_2)
     assert len(dataset_2_iter) == 611
     process_2_1 = list(dataset_2_iter)
-    assert process_2_1[:10] == [813, 815, 816, 812, 818, 811, 817, 814, 819, 277]
+    assert process_2_1[:10] == [818, 812, 816, 811, 819, 813, 815, 814, 817, 273]
     assert len(process_2_1) == 611
-
     assert len([i for i in process_1_1 if i in process_2_1]) == 0
 
 
@@ -530,7 +529,7 @@ def test_s3_streaming_dataset():
     assert dataset.input_dir.path is None
 
 
-def test_resumable_dataset(tmpdir):
+def test_resumable_dataset_single_worker(tmpdir):
     seed_everything(42)
 
     block_size = 20

tests/tests_pytorch/checkpointing/test_model_checkpoint.py

Lines changed: 4 additions & 4 deletions

@@ -1474,8 +1474,8 @@ def test_train_epoch_end_ckpt_with_no_validation():
     assert not trainer.checkpoint_callback._should_save_on_train_epoch_end(trainer)
 
 
-@pytest.mark.parametrize("same_resume_folder", [True, False])
-def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path):
+@pytest.mark.parametrize("same_checkpoint_dir", [True, False])
+def test_resume_and_old_checkpoint_files_remain(same_checkpoint_dir, tmp_path):
     """Test that checkpoints saved in the resume-folder won't be deleted under the save-top-k mechanism."""
     model = BoringModel()
     trainer_kwargs = {
@@ -1488,7 +1488,7 @@ def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path):
     }
     first = tmp_path / "first"
     second = tmp_path / "second"
-    new_dirpath = first if same_resume_folder else second
+    new_dirpath = first if same_checkpoint_dir else second
 
     # Generate checkpoints in the first folder
     callback = ModelCheckpoint(dirpath=first, monitor="step", mode="max", save_top_k=2, every_n_train_steps=2)
@@ -1500,7 +1500,7 @@ def test_resume_and_old_checkpoint_files_remain(same_resume_folder, tmp_path):
     callback = ModelCheckpoint(dirpath=new_dirpath, monitor="step", mode="max", save_top_k=2, every_n_train_steps=2)
     trainer = Trainer(callbacks=callback, max_steps=8, **trainer_kwargs)
     trainer.fit(model, ckpt_path=str(first / "epoch=0-step=4.ckpt"))
-    if same_resume_folder:
+    if same_checkpoint_dir:
         assert set(os.listdir(first)) == {
             "epoch=0-step=4.ckpt",  # do not delete checkpoint from which we resume from
             "epoch=0-step=6.ckpt",
