Skip to content

Commit e4e52f8

Browse files
feat: Implement embed_image() (#5101)
## Changes Made Adds support for `embed_image()`, e.g. ``` import daft from daft.functions.ai import embed_image import numpy as np provider = "transformers" model = "openai/clip-vit-base-patch32" test_image = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) ( daft.from_pydict({"image": [test_image] * 16}) .select(daft.col("image").cast(daft.DataType.image())) .select(embed_image(daft.col("image"), provider=provider, model=model)) .show() ) ``` **!! Currently only supports OpenAI CLIP models: https://huggingface.co/docs/transformers/en/model_doc/clip** --------- Co-authored-by: R. C. Howell <[email protected]>
1 parent ce3fd01 commit e4e52f8

File tree

16 files changed

+951
-79
lines changed

16 files changed

+951
-79
lines changed

daft/ai/_expressions.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
if TYPE_CHECKING:
66
from daft import Series
7-
from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor
7+
from daft.ai.protocols import ImageEmbedder, ImageEmbedderDescriptor, TextEmbedder, TextEmbedderDescriptor
88
from daft.ai.typing import Embedding
99

1010

@@ -19,3 +19,16 @@ def __init__(self, text_embedder: TextEmbedderDescriptor):
1919
def __call__(self, text_series: Series) -> list[Embedding]:
2020
text = text_series.to_pylist()
2121
return self.text_embedder.embed_text(text) if text else []
22+
23+
24+
class _ImageEmbedderExpression:
25+
"""Function expression implementation for an ImageEmbedder protocol."""
26+
27+
image_embedder: ImageEmbedder
28+
29+
def __init__(self, image_embedder: ImageEmbedderDescriptor):
30+
self.image_embedder = image_embedder.instantiate()
31+
32+
def __call__(self, image_series: Series) -> list[Embedding]:
33+
image = image_series.to_pylist()
34+
return self.image_embedder.embed_image(image) if image else []

daft/ai/openai/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
if TYPE_CHECKING:
1111
from daft.ai.openai.typing import OpenAIProviderOptions
12-
from daft.ai.protocols import TextEmbedder, TextEmbedderDescriptor
12+
from daft.ai.protocols import ImageEmbedderDescriptor, TextEmbedderDescriptor
1313
from daft.ai.typing import Options
1414

1515
__all__ = [
@@ -36,3 +36,6 @@ def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmb
3636
model_name=(model or "text-embedding-3-small"),
3737
model_options=options,
3838
)
39+
40+
def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
    """Image embedding is not yet supported for this provider; always raises NotImplementedError."""
    raise NotImplementedError("embed_image is not currently implemented for the OpenAI provider")

daft/ai/protocols.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from daft.ai.typing import Descriptor
77

88
if TYPE_CHECKING:
9-
from daft.ai.typing import Embedding, EmbeddingDimensions
9+
from daft.ai.typing import Embedding, EmbeddingDimensions, Image
1010

1111

1212
@runtime_checkable
@@ -24,3 +24,20 @@ class TextEmbedderDescriptor(Descriptor[TextEmbedder]):
2424
@abstractmethod
2525
def get_dimensions(self) -> EmbeddingDimensions:
2626
"""Returns the dimensions of the embeddings produced by the described TextEmbedder."""
27+
28+
29+
@runtime_checkable
class ImageEmbedder(Protocol):
    """Protocol for image embedding implementations."""

    def embed_image(self, images: list[Image]) -> list[Embedding]:
        """Embeds a batch of images, returning one embedding vector per input image."""
        ...
36+
37+
38+
class ImageEmbedderDescriptor(Descriptor[ImageEmbedder]):
    """Descriptor for an ImageEmbedder implementation.

    A descriptor is a lightweight, serializable handle; the underlying model is
    only loaded when ``instantiate()`` (inherited from ``Descriptor``) is called.
    """

    @abstractmethod
    def get_dimensions(self) -> EmbeddingDimensions:
        """Returns the dimensions of the embeddings produced by the described ImageEmbedder."""

daft/ai/provider.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
if TYPE_CHECKING:
99
from daft.ai.openai.typing import OpenAIProviderOptions
10-
from daft.ai.protocols import TextEmbedderDescriptor
10+
from daft.ai.protocols import ImageEmbedderDescriptor, TextEmbedderDescriptor
1111

1212

1313
class ProviderImportError(ImportError):
@@ -34,9 +34,19 @@ def load_sentence_transformers(name: str | None = None, **options: Any) -> Provi
3434
raise ProviderImportError(["sentence_transformers", "torch"]) from e
3535

3636

37+
def load_transformers(name: str | None = None, **options: Any) -> Provider:
    """Builds a TransformersProvider, translating a missing optional dependency into ProviderImportError."""
    try:
        from daft.ai.transformers import TransformersProvider

        return TransformersProvider(name, **options)
    except ImportError as e:
        # Report the full set of extras this provider needs, not just the import that failed.
        raise ProviderImportError(["torch", "torchvision", "transformers", "Pillow"]) from e
44+
45+
3746
# Registry mapping a provider's short name to its lazy loader; each loader defers
# heavy optional imports until that provider is actually requested.
PROVIDERS: dict[str, Callable[..., Provider]] = {
    "openai": load_openai,
    "sentence_transformers": load_sentence_transformers,
    "transformers": load_transformers,
}
4151

4252

@@ -65,3 +75,8 @@ def name(self) -> str:
6575
def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
6676
"""Returns a TextEmbedderDescriptor for this provider."""
6777
...
78+
79+
@abstractmethod
def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
    """Returns an ImageEmbedderDescriptor for this provider.

    Args:
        model: Optional model identifier; each provider chooses its own default when None.
        **options: Provider-specific options forwarded to the descriptor.
    """
    ...

daft/ai/sentence_transformers/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing import TYPE_CHECKING, Any
77

88
if TYPE_CHECKING:
9-
from daft.ai.protocols import TextEmbedderDescriptor
9+
from daft.ai.protocols import ImageEmbedderDescriptor, TextEmbedderDescriptor
1010
from daft.ai.typing import Options
1111

1212
__all__ = [
@@ -28,3 +28,6 @@ def name(self) -> str:
2828

2929
def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
3030
return SentenceTransformersTextEmbedderDescriptor(model or "sentence-transformers/all-MiniLM-L6-v2", options)
31+
32+
def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
    """Image embedding is not yet supported for this provider; always raises NotImplementedError."""
    raise NotImplementedError("embed_image is not currently implemented for the Sentence Transformers provider")

daft/ai/sentence_transformers/text_embedder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def get_options(self) -> Options:
3030
return self.options
3131

3232
def get_dimensions(self) -> EmbeddingDimensions:
    # NOTE(review): trust_remote_code=True executes code shipped with the model repo at
    # load time — confirm this is an acceptable trust boundary for user-supplied model names.
    dimensions = AutoConfig.from_pretrained(self.model, trust_remote_code=True).hidden_size
    return EmbeddingDimensions(size=dimensions, dtype=DataType.float32())
3535

3636
def instantiate(self) -> TextEmbedder:

daft/ai/transformers/__init__.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from __future__ import annotations
2+
3+
from daft.ai.provider import Provider
4+
5+
from daft.ai.transformers.image_embedder import TransformersImageEmbedderDescriptor
6+
from typing import TYPE_CHECKING, Any
7+
8+
if TYPE_CHECKING:
9+
from daft.ai.protocols import ImageEmbedderDescriptor, TextEmbedderDescriptor
10+
from daft.ai.typing import Options
11+
12+
__all__ = [
13+
"TransformersProvider",
14+
]
15+
16+
17+
class TransformersProvider(Provider):
    """Provider backed by Hugging Face Transformers models.

    Currently only image embedding is implemented; text embedding raises
    NotImplementedError.
    """

    _name: str
    _options: Options

    def __init__(self, name: str | None = None, **options: Any):
        self._name = name or "transformers"
        self._options = options

    @property
    def name(self) -> str:
        return self._name

    def get_image_embedder(self, model: str | None = None, **options: Any) -> ImageEmbedderDescriptor:
        # Default to OpenAI's CLIP ViT-B/32 checkpoint when no model is given.
        return TransformersImageEmbedderDescriptor(model or "openai/clip-vit-base-patch32", options)

    def get_text_embedder(self, model: str | None = None, **options: Any) -> TextEmbedderDescriptor:
        raise NotImplementedError("embed_text is not currently implemented for the Transformers provider")
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass
4+
from typing import TYPE_CHECKING, Any
5+
6+
import torch
7+
from transformers import AutoConfig, AutoModel, AutoProcessor
8+
9+
from daft import DataType
10+
from daft.ai.protocols import ImageEmbedder, ImageEmbedderDescriptor
11+
from daft.ai.typing import EmbeddingDimensions, Options
12+
from daft.ai.utils import get_device
13+
from daft.dependencies import pil_image
14+
15+
if TYPE_CHECKING:
16+
from daft.ai.typing import Embedding, Image
17+
18+
19+
@dataclass
class TransformersImageEmbedderDescriptor(ImageEmbedderDescriptor):
    """Serializable description of a Transformers image embedder (model name plus options)."""

    # Hugging Face model name or local path, e.g. "openai/clip-vit-base-patch32".
    model: str
    options: Options

    def get_provider(self) -> str:
        return "transformers"

    def get_model(self) -> str:
        return self.model

    def get_options(self) -> Options:
        return self.options

    def get_dimensions(self) -> EmbeddingDimensions:
        # NOTE(review): trust_remote_code=True executes code shipped with the model repo at
        # load time — confirm this is an acceptable trust boundary for user-supplied models.
        config = AutoConfig.from_pretrained(self.model, trust_remote_code=True)
        # For CLIP models, the image embedding dimension is typically in projection_dim or hidden_size.
        # NOTE(review): getattr defaults only apply when the attribute is *absent*; a config with
        # projection_dim=None would pass None through — verify against the supported model configs.
        embedding_size = getattr(config, "projection_dim", getattr(config, "hidden_size", 512))
        return EmbeddingDimensions(size=embedding_size, dtype=DataType.float32())

    def instantiate(self) -> ImageEmbedder:
        # Loads model weights; deferred until an embedder is actually needed.
        return TransformersImageEmbedder(self.model, **self.options)
41+
42+
43+
class TransformersImageEmbedder(ImageEmbedder):
    """Embeds images with a Transformers vision model.

    The model must expose ``get_image_features`` (CLIP-style); other
    architectures are not supported by this implementation.
    """

    model: Any
    options: Options

    def __init__(self, model_name_or_path: str, **options: Any):
        self.device = get_device()
        # NOTE(review): trust_remote_code=True executes code shipped with the model repo at
        # load time — confirm this is an acceptable trust boundary for user-supplied models.
        self.model = AutoModel.from_pretrained(
            model_name_or_path,
            trust_remote_code=True,
            use_safetensors=True,
        ).to(self.device)
        self.processor = AutoProcessor.from_pretrained(model_name_or_path, trust_remote_code=True, use_fast=True)
        self.options = options

    def embed_image(self, images: list[Image]) -> list[Embedding]:
        # TODO(desmond): There's potential for image decoding and processing on the GPU with greater
        # performance. Methods differ a little between different models, so let's do it later.
        # Images arrive as numpy arrays (see daft.ai.typing.Image); convert to PIL for the processor.
        pil_images = [pil_image.fromarray(image) for image in images]
        processed = self.processor(images=pil_images, return_tensors="pt")
        pixel_values = processed["pixel_values"].to(self.device)

        with torch.inference_mode():
            # get_image_features is CLIP-specific — presumably why only CLIP models are supported.
            embeddings = self.model.get_image_features(pixel_values)
        # NOTE(review): tolist() yields nested Python lists, while Embedding is annotated as an
        # ndarray elsewhere — confirm downstream consumers accept plain lists.
        return embeddings.cpu().numpy().tolist()

daft/ai/typing.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
"Descriptor",
1717
"Embedding",
1818
"EmbeddingDimensions",
19+
"Image",
1920
]
2021

2122

@@ -47,8 +48,10 @@ def instantiate(self) -> T:
4748
from daft.dependencies import np
4849

4950
Embedding: TypeAlias = np.typing.NDArray[Any]
51+
Image: TypeAlias = np.ndarray[Any, Any]
5052
else:
5153
Embedding: TypeAlias = Any
54+
Image: TypeAlias = Any
5255

5356

5457
@dataclass

daft/ai/utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
5+
if TYPE_CHECKING:
6+
import torch
7+
8+
9+
def get_device() -> torch.device:
    """Get the best available PyTorch device for computation.

    Preference order:
    1. CUDA GPU (if available) - for NVIDIA GPUs with CUDA support
    2. MPS (Metal Performance Shaders) - for Apple Silicon Macs
    3. CPU - as fallback when no GPU acceleration is available
    """
    # Imported lazily so importing this module does not require torch.
    import torch

    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

0 commit comments

Comments
 (0)