|
21 | 21 | from lightning import seed_everything
|
22 | 22 | from lightning.data.datasets.env import _DistributedEnv
|
23 | 23 | from lightning.data.streaming import Cache
|
| 24 | +from lightning.data.streaming import cache as cache_module |
24 | 25 | from lightning.data.streaming.dataloader import StreamingDataLoader
|
25 | 26 | from lightning.fabric import Fabric
|
26 | 27 | from lightning.pytorch.demos.boring_classes import RandomDataset
|
@@ -203,3 +204,18 @@ def __len__(self) -> int:
|
203 | 204 | with pytest.raises(ValueError, match="Your dataset items aren't deterministic"):
|
204 | 205 | for batch in dataloader:
|
205 | 206 | pass
|
| 207 | + |
| 208 | + |
| 209 | +def test_cache_with_name(tmpdir, monkeypatch): |
| 210 | + with pytest.raises(FileNotFoundError, match="The provided cache directory"): |
| 211 | + Cache(name="something") |
| 212 | + |
| 213 | + os.makedirs(os.path.join(tmpdir, "something"), exist_ok=True) |
| 214 | + os.makedirs(os.path.join(tmpdir, "remote_dir"), exist_ok=True) |
| 215 | + monkeypatch.setattr(cache_module, "_try_create_cache_dir", lambda name: os.path.join(tmpdir, name)) |
| 216 | + |
| 217 | + monkeypatch.setattr(cache_module, "_find_remote_dir", lambda name, _: (os.path.join(tmpdir, "remote_dir"), True)) |
| 218 | + cache = Cache(name="something") |
| 219 | + assert cache._writer._chunk_size == 2 |
| 220 | + assert cache._writer._cache_dir == os.path.join(tmpdir, "something") |
| 221 | + assert cache._reader._remote_dir == os.path.join(tmpdir, "remote_dir") |
0 commit comments