Skip to content

Commit 9dfce03

Browse files
fineguyThe TensorFlow Datasets Authors
authored andcommitted
Add subfolder option to incomplete_files and use it in writer.py.
PiperOrigin-RevId: 804868748
1 parent 57cad96 commit 9dfce03

File tree

3 files changed

+33
-17
lines changed

3 files changed

+33
-17
lines changed

tensorflow_datasets/core/utils/py_utils.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def _tmp_file_name(
312312
path: epath.PathLike,
313313
subfolder: str | None = None,
314314
) -> epath.Path:
315-
"""Returns the temporary file name for the given path.
315+
"""Returns the temporary file path dependent on the given path and subfolder.
316316
317317
Args:
318318
path: The path to the file.
@@ -322,9 +322,12 @@ def _tmp_file_name(
322322
path = epath.Path(path)
323323
file_name = f'{_tmp_file_prefix()}.{path.name}'
324324
if subfolder:
325-
return path.parent / subfolder / file_name
325+
tmp_path = path.parent / subfolder / file_name
326326
else:
327-
return path.parent / file_name
327+
tmp_path = path.parent / file_name
328+
# Create the parent directory if it doesn't exist.
329+
tmp_path.parent.mkdir(parents=True, exist_ok=True)
330+
return tmp_path
328331

329332

330333
@contextlib.contextmanager
@@ -334,7 +337,6 @@ def incomplete_file(
334337
) -> Iterator[epath.Path]:
335338
"""Writes to path atomically, by writing to temp file and renaming it."""
336339
tmp_path = _tmp_file_name(path, subfolder=subfolder)
337-
tmp_path.parent.mkdir(exist_ok=True)
338340
try:
339341
yield tmp_path
340342
tmp_path.replace(path)
@@ -346,20 +348,24 @@ def incomplete_file(
346348
@contextlib.contextmanager
347349
def incomplete_files(
348350
path: epath.Path,
351+
subfolder: str | None = None,
349352
) -> Iterator[epath.Path]:
350353
"""Writes to path atomically, by writing to temp file and renaming it."""
351-
tmp_file_prefix = _tmp_file_prefix()
352-
tmp_path = path.parent / f'{tmp_file_prefix}.{path.name}'
354+
tmp_path = _tmp_file_name(path, subfolder=subfolder)
355+
tmp_file_prefix = tmp_path.name.removesuffix(f'.{path.name}')
353356
try:
354357
yield tmp_path
355358
# Rename all tmp files to their final name.
356-
for tmp_file in path.parent.glob(f'{tmp_file_prefix}.*'):
359+
for tmp_file in tmp_path.parent.glob(f'{tmp_file_prefix}.*'):
357360
file_name = tmp_file.name.removeprefix(tmp_file_prefix + '.')
358361
tmp_file.replace(path.parent / file_name)
359362
finally:
360363
# Eventually delete the tmp_path if exception was raised
361-
for tmp_file in path.parent.glob(f'{tmp_file_prefix}.*'):
362-
tmp_file.unlink(missing_ok=True)
364+
if subfolder:
365+
tmp_path.parent.unlink(missing_ok=True)
366+
else:
367+
for tmp_file in tmp_path.parent.glob(f'{tmp_file_prefix}.*'):
368+
tmp_file.unlink(missing_ok=True)
363369

364370

365371
def is_incomplete_file(path: epath.Path) -> bool:

tensorflow_datasets/core/utils/py_utils_test.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -373,15 +373,23 @@ def test_make_valid_name(name: str, expected: str):
373373

374374

375375
@pytest.mark.parametrize(
376-
['path', 'subfolder', 'expected'],
376+
['file_name', 'subfolder', 'expected_tmp_file_name'],
377377
[
378-
('/a/file.ext', None, '/a/foobar.file.ext'),
379-
('/a/file.ext', 'sub', '/a/sub/foobar.file.ext'),
378+
('file.ext', None, 'foobar.file.ext'),
379+
('file.ext', 'sub', 'sub/foobar.file.ext'),
380380
],
381381
)
382-
def test_tmp_file_name(path, subfolder, expected):
382+
def test_tmp_file_name(
383+
tmp_path: pathlib.Path,
384+
file_name: str,
385+
subfolder: str | None,
386+
expected_tmp_file_name: str,
387+
):
383388
with mock.patch.object(py_utils, '_tmp_file_prefix', return_value='foobar'):
384-
assert os.fspath(py_utils._tmp_file_name(path, subfolder)) == expected
389+
assert (
390+
py_utils._tmp_file_name(tmp_path / file_name, subfolder)
391+
== tmp_path / expected_tmp_file_name
392+
)
385393

386394

387395
if __name__ == '__main__':

tensorflow_datasets/core/writer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -572,12 +572,14 @@ def _write_final_shard(
572572
shard_path = self._filename_template.sharded_filepath(
573573
shard_index=shard_id, num_shards=len(non_empty_shard_ids)
574574
)
575-
with utils.incomplete_files(epath.Path(shard_path)) as tmp_path:
575+
with utils.incomplete_files(shard_path, subfolder="tmp") as tmp_shard_path:
576576
logging.info(
577-
"Writing %d examples to %s.", len(example_by_key), os.fspath(tmp_path)
577+
"Writing %d examples to %s.",
578+
len(example_by_key),
579+
os.fspath(tmp_shard_path),
578580
)
579581
record_keys = self._example_writer.write(
580-
tmp_path, sorted(example_by_key.items())
582+
tmp_shard_path, sorted(example_by_key.items())
581583
)
582584
self.inc_counter(name="written_shards")
583585
# If there are record_keys, create index files.

0 commit comments

Comments
 (0)