Skip to content

Commit 950922d

Browse files
feat: adds video frame streaming source (#4979)
## Changes Made * Adds a `daft.read_video_frames` method. * Adds an example for how to implement av sources. * Initializes the `daft.io.av` package for various audio-video io. * Adds documentation and examples. ## Related Issues n/a ## Checklist - [x] Documented in API Docs (if applicable) - [x] Documented in User Guide (if applicable) - [n/a] If adding a new documentation page, doc is added to `docs/mkdocs.yml` navigation - [x] Documentation builds and is formatted properly (tag @/ccmao1130 for docs review) --------- Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent ff31546 commit 950922d

File tree

11 files changed

+501
-16
lines changed

11 files changed

+501
-16
lines changed

daft/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def refresh_logger() -> None:
8888
read_parquet,
8989
read_sql,
9090
read_lance,
91+
read_video_frames,
9192
read_warc,
9293
)
9394
from daft.series import Series
@@ -211,6 +212,7 @@ def refresh_logger() -> None:
211212
"read_parquet",
212213
"read_sql",
213214
"read_table",
215+
"read_video_frames",
214216
"read_warc",
215217
"refresh_logger",
216218
"register_viz_hook",

daft/ai/typing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from abc import ABC, abstractmethod
44
from dataclasses import dataclass
5-
from typing import Generic, TypeVar
5+
from typing import Any, Generic, TypeVar
66

77
from daft.datatype import DataType
88
from daft.dependencies import np
@@ -36,7 +36,7 @@ def instantiate(self) -> T:
3636

3737

3838
# temp definition to defer complexity of a more generic embedding type to later PRs
39-
Embedding = np.typing.NDArray # type: ignore[type-arg]
39+
Embedding = np.typing.NDArray[Any]
4040

4141

4242
@dataclass

daft/filesystem.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def _resolve_paths_and_filesystem(
115115
paths: str | pathlib.Path | list[str],
116116
io_config: IOConfig | None = None,
117117
) -> tuple[list[str], pafs.FileSystem]:
118-
"""Resolves and normalizes the provided path and infers it's filesystem.
118+
"""Resolves and normalizes the provided path and infers its filesystem.
119119
120120
Also ensures that the inferred filesystem is compatible with the passed filesystem, if provided.
121121
@@ -200,6 +200,8 @@ def _infer_filesystem(
200200
"""
201201
protocol = get_protocol_from_path(path)
202202
translated_kwargs: dict[str, Any]
203+
resolved_filesystem: pafs.FileSystem
204+
expiry: datetime | None = None
203205

204206
def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None) -> None:
205207
"""Helper method used when setting kwargs for pyarrow."""
@@ -228,7 +230,6 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None) -> None:
228230
except ImportError:
229231
pass # Config does not exist in pyarrow 7.0.0
230232

231-
expiry = None
232233
if (s3_creds := s3_config.provide_cached_credentials()) is not None:
233234
_set_if_not_none(translated_kwargs, "access_key", s3_creds.key_id)
234235
_set_if_not_none(translated_kwargs, "secret_key", s3_creds.access_key)
@@ -277,8 +278,8 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None) -> None:
277278
elif protocol in {"http", "https"}:
278279
fsspec_fs_cls = fsspec.get_filesystem_class(protocol)
279280
fsspec_fs = fsspec_fs_cls()
280-
resolved_filesystem, resolved_path = pafs._resolve_filesystem_and_path(path, fsspec_fs)
281-
resolved_path = resolved_filesystem.normalize_path(resolved_path)
281+
resolved_filesystem = pafs.PyFileSystem(fsspec_fs)
282+
resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(path))
282283
return resolved_path, resolved_filesystem, None
283284

284285
###
@@ -300,8 +301,8 @@ def _set_if_not_none(kwargs: dict[str, Any], key: str, val: Any | None) -> None:
300301
)
301302
else:
302303
fsspec_fs = fsspec_fs_cls()
303-
resolved_filesystem, resolved_path = pafs._resolve_filesystem_and_path(path, fsspec_fs)
304-
resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(resolved_path))
304+
resolved_filesystem = pafs.PyFileSystem(fsspec_fs)
305+
resolved_path = resolved_filesystem.normalize_path(_unwrap_protocol(path))
305306
return resolved_path, resolved_filesystem, None
306307

307308
else:

daft/io/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from daft.io.file_path import from_glob_path
2424
from daft.io.sink import DataSink
2525
from daft.io.source import DataSource, DataSourceTask
26+
from daft.io.av import read_video_frames
2627

2728
__all__ = [
2829
"AzureConfig",
@@ -47,5 +48,6 @@
4748
"read_lance",
4849
"read_parquet",
4950
"read_sql",
51+
"read_video_frames",
5052
"read_warc",
5153
]

daft/io/av/__init__.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
from daft.dataframe.dataframe import DataFrame
7+
from daft.daft import IOConfig
8+
9+
__all__ = [
10+
# TODO: additional video support
11+
# "read_audio_frames",
12+
# "read_audio_streams",
13+
# "read_audio_streams_metadata",
14+
# "read_subtitle_frames",
15+
# "read_subtitle_streams",
16+
# "read_subtitle_streams_metadata",
17+
"read_video_frames",
18+
# "read_video_streams",
19+
# "read_video_streams_metadata",
20+
]
21+
22+
23+
def read_video_frames(
24+
path: str | list[str],
25+
image_height: int,
26+
image_width: int,
27+
is_key_frame: bool | None = None,
28+
io_config: IOConfig | None = None,
29+
) -> DataFrame:
30+
"""Creates a DataFrame by reading the frames of one or more video files.
31+
32+
This produces a DataFrame with the following fields:
33+
* path (string): path to the video file that produced this frame.
34+
* frame_index (int): frame index in the video.
35+
* frame_time (float): frame time in fractional seconds as a floating point.
36+
* frame_time_base (str): fractional unit of seconds in which timestamps are expressed.
37+
* frame_pts (int): frame presentation timestamp in time_base units.
38+
* frame_dts (int): frame decoding timestamp in time_base units.
39+
* frame_duration (int): frame duration in time_base units.
40+
* is_key_frame (bool): true iff this is a key frame.
41+
42+
Warning:
43+
This requires PyAV which can be installed with `pip install av`.
44+
45+
Note:
46+
This function will stream the frames from all videos as a DataFrame of images.
47+
If you wish to load an entire video into a single row, this can be done with
48+
read_glob_path and url.download.
49+
50+
Args:
51+
path (str|list[str]): Path(s) to the video file(s) which allows wildcards.
52+
image_height (int): Height to which each frame will be resized.
53+
image_width (int): Width to which each frame will be resized.
54+
is_key_frame (bool|None): If True, only include key frames; if False, only non-key frames; if None, include all frames.
55+
io_config (IOConfig|None): Optional IOConfig.
56+
57+
Returns:
58+
DataFrame: dataframe of images.
59+
60+
Examples:
61+
>>> df = daft.read_video_frames("/path/to/file.mp4", image_height=480, image_width=640)
62+
>>> df = daft.read_video_frames("/path/to/directory", image_height=480, image_width=640)
63+
>>> df = daft.read_video_frames("/path/to/files-*.mp4", image_height=480, image_width=640)
64+
>>> df = daft.read_video_frames("s3://path/to/files-*.mp4", image_height=480, image_width=640)
65+
"""
66+
try:
67+
from daft.io.av._read_video_frames import _VideoFramesSource
68+
except ImportError as e:
69+
raise ImportError("read_video_frames requires PyAV. Please install it with `pip install av`.") from e
70+
return _VideoFramesSource(
71+
paths=[path] if isinstance(path, str) else path,
72+
image_height=image_height,
73+
image_width=image_width,
74+
is_key_frame=is_key_frame,
75+
io_config=io_config,
76+
).read()

0 commit comments

Comments
 (0)