Skip to content

Commit 890b2be

Browse files
thomasthomas
authored andcommitted
update
1 parent ac18d05 commit 890b2be

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tests/tests_data/cache/test_cache.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from lightning import seed_everything
2222
from lightning.data.datasets.env import _DistributedEnv
2323
from lightning.data.streaming import Cache
24+
from lightning.data.streaming import cache as cache_module
2425
from lightning.data.streaming.dataloader import StreamingDataLoader
2526
from lightning.fabric import Fabric
2627
from lightning.pytorch.demos.boring_classes import RandomDataset
@@ -203,3 +204,18 @@ def __len__(self) -> int:
203204
with pytest.raises(ValueError, match="Your dataset items aren't deterministic"):
204205
for batch in dataloader:
205206
pass
207+
208+
209+
def test_cache_with_name(tmpdir, monkeypatch):
210+
with pytest.raises(FileNotFoundError, match="The provided cache directory"):
211+
Cache(name="something")
212+
213+
os.makedirs(os.path.join(tmpdir, "something"), exist_ok=True)
214+
os.makedirs(os.path.join(tmpdir, "remote_dir"), exist_ok=True)
215+
monkeypatch.setattr(cache_module, "_try_create_cache_dir", lambda name: os.path.join(tmpdir, name))
216+
217+
monkeypatch.setattr(cache_module, "_find_remote_dir", lambda name, _: (os.path.join(tmpdir, "remote_dir"), True))
218+
cache = Cache(name="something")
219+
assert cache._writer._chunk_size == 2
220+
assert cache._writer._cache_dir == os.path.join(tmpdir, "something")
221+
assert cache._reader._remote_dir == os.path.join(tmpdir, "remote_dir")

0 commit comments

Comments
 (0)