packages/backend/embedding_atlas/cli.py (9 additions, 0 deletions)

@@ -130,6 +130,12 @@ def find_available_port(start_port: int, max_attempts: int = 10, host="localhost
     default=False,
     help="Allow execution of remote code when loading models from Hugging Face Hub.",
 )
+@click.option(
+    "--batch-size",
+    type=int,
+    default=None,
+    help="Batch size for processing embeddings (default: 32 for text, 16 for images). Larger values use more memory but may be faster.",
+)
 @click.option(
     "--x",
     "x_column",
@@ -207,6 +213,7 @@ def main(
     enable_projection: bool,
     model: str | None,
     trust_remote_code: bool,
+    batch_size: int | None,
     x_column: str | None,
     y_column: str | None,
     neighbors_column: str | None,
@@ -280,6 +287,7 @@ def main(
             neighbors=new_neighbors_column,
             model=model,
             trust_remote_code=trust_remote_code,
+            batch_size=batch_size,
             umap_args=umap_args,
         )
     elif image is not None:
@@ -291,6 +299,7 @@
             neighbors=new_neighbors_column,
             model=model,
             trust_remote_code=trust_remote_code,
+            batch_size=batch_size,
             umap_args=umap_args,
         )
     else:
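For context, the pattern in this file is an optional integer option that defaults to None so each modality can resolve its own default (32 for text, 16 for images) later in projection.py. A minimal self-contained sketch of the same pattern, assuming click is installed; everything outside the option declaration itself is illustrative:

```python
import click


@click.command()
@click.option(
    "--batch-size",
    type=int,
    default=None,
    help="Batch size for embeddings; None defers to per-modality defaults.",
)
def main(batch_size: int | None):
    # None is resolved downstream: 32 for text, 16 for images.
    click.echo(f"batch_size={batch_size}")


if __name__ == "__main__":
    main()
```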
packages/backend/embedding_atlas/projection.py (25 additions, 4 deletions)

@@ -92,6 +92,7 @@ def _projection_for_texts(
     texts: list[str],
     model: str | None = None,
     trust_remote_code: bool = False,
+    batch_size: int | None = None,
     umap_args: dict = {},
 ) -> Projection:
     if model is None:
@@ -102,6 +103,7 @@
             "version": 1,
             "texts": texts,
             "model": model,
+            "batch_size": batch_size,
             "umap_args": umap_args,
         }
     )
@@ -115,11 +117,16 @@
     # Import on demand.
     from sentence_transformers import SentenceTransformer
 
+    # Set default batch size if not provided
+    if batch_size is None:
+        batch_size = 32
+        logger.info("Using default batch size of %d for text. Adjust with --batch-size if you encounter memory issues or want to speed up processing.", batch_size)
+
     logger.info("Loading model %s...", model)
     transformer = SentenceTransformer(model, trust_remote_code=trust_remote_code)
 
-    logger.info("Running embedding for %d texts...", len(texts))
-    hidden_vectors = transformer.encode(texts)
+    logger.info("Running embedding for %d texts with batch size %d...", len(texts), batch_size)
+    hidden_vectors = transformer.encode(texts, batch_size=batch_size)
 
     result = _run_umap(hidden_vectors, umap_args)
     Projection.save(cpath, result)
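To try the text path in isolation, here is a standalone sketch of the new encode call, assuming sentence-transformers is installed; the model name is only an example, and the project's caching and UMAP steps are omitted:

```python
from sentence_transformers import SentenceTransformer

texts = ["first document", "second document", "third document"]

batch_size = None  # e.g. the value passed through --batch-size
if batch_size is None:
    batch_size = 32  # text default, as in this diff

# Example model, not necessarily the project's default.
model = SentenceTransformer("all-MiniLM-L6-v2")

# encode() takes batch_size directly; larger batches trade memory for speed.
vectors = model.encode(texts, batch_size=batch_size)
print(vectors.shape)  # (3, 384) for this model
```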
@@ -130,6 +137,7 @@ def _projection_for_images(
     images: list,
     model: str | None = None,
     trust_remote_code: bool = False,
+    batch_size: int | None = None,
     umap_args: dict = {},
 ) -> Projection:
     if model is None:
@@ -140,6 +148,7 @@
             "version": 1,
             "images": images,
             "model": model,
+            "batch_size": batch_size,
             "umap_args": umap_args,
         }
     )
@@ -170,9 +179,13 @@ def load_image(value):
 
     pipe = pipeline("image-feature-extraction", model=model, device_map="auto")
 
-    logger.info("Running embedding for %d images...", len(images))
+    # Set default batch size if not provided
+    if batch_size is None:
+        batch_size = 16
+        logger.info("Using default batch size of %d for images. Adjust with --batch-size if you encounter memory issues or want to speed up processing.", batch_size)
+
+    logger.info("Running embedding for %d images with batch size %d...", len(images), batch_size)
     tensors = []
-    batch_size = 16
 
     current_batch = []
 
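The batching loop body is collapsed in the diff above. A minimal sketch of feeding images through a Hugging Face image-feature-extraction pipeline in fixed-size batches; the chunking details are an assumption for illustration, not the repository's exact code:

```python
from transformers import pipeline

# Example model; the project resolves `model` elsewhere.
pipe = pipeline("image-feature-extraction", model="google/vit-base-patch16-224")


def embed_images(images, batch_size=16):
    tensors = []
    current_batch = []
    for image in images:
        current_batch.append(image)
        if len(current_batch) == batch_size:
            tensors.extend(pipe(current_batch))  # one pipeline call per full batch
            current_batch = []
    if current_batch:
        tensors.extend(pipe(current_batch))  # flush the final partial batch
    return tensors
```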
@@ -207,6 +220,7 @@ def compute_text_projection(
     neighbors: str | None = "neighbors",
     model: str | None = None,
     trust_remote_code: bool = False,
+    batch_size: int | None = None,
     umap_args: dict = {},
 ):
     """
@@ -225,6 +239,8 @@
         model: str, name or path of the SentenceTransformer model to use for embedding.
         trust_remote_code: bool, whether to trust and execute remote code when loading
             the model from HuggingFace Hub. Default is False.
+        batch_size: int, batch size for processing embeddings. Larger values use more
+            memory but may be faster. Default is 32.
         umap_args: dict, additional keyword arguments to pass to the UMAP algorithm
             (e.g., n_neighbors, min_dist, metric).
 
@@ -237,6 +253,7 @@
         list(text_series),
         model=model,
         trust_remote_code=trust_remote_code,
+        batch_size=batch_size,
         umap_args=umap_args,
     )
     data_frame[x] = proj.projection[:, 0]
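Calling the Python API with the new parameter might look like the following; the `text` argument is taken from the docstring, while the file name, column name, and default output columns (`projection_x`, `projection_y`) are assumptions:

```python
import pandas as pd

from embedding_atlas.projection import compute_text_projection

df = pd.read_parquet("reviews.parquet")  # hypothetical table with a text column

# batch_size=64 overrides the text default of 32; leaving it None keeps the default.
compute_text_projection(df, text="review_text", batch_size=64)

# Assumed default output columns; adjust if your setup names them differently.
print(df[["projection_x", "projection_y"]].head())
```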
@@ -314,6 +331,7 @@ def compute_image_projection(
     neighbors: str | None = "neighbors",
     model: str | None = None,
     trust_remote_code: bool = False,
+    batch_size: int | None = None,
     umap_args: dict = {},
 ):
     """
@@ -332,6 +350,8 @@
         model: str, name or path of the model to use for embedding.
         trust_remote_code: bool, whether to trust and execute remote code when loading
             the model from HuggingFace Hub. Default is False.
+        batch_size: int, batch size for processing images. Larger values use more
+            memory but may be faster. Default is 16.
         umap_args: dict, additional keyword arguments to pass to the UMAP algorithm
             (e.g., n_neighbors, min_dist, metric).
 
@@ -344,6 +364,7 @@
         list(image_series),
         model=model,
         trust_remote_code=trust_remote_code,
+        batch_size=batch_size,
         umap_args=umap_args,
     )
     data_frame[x] = proj.projection[:, 0]
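And the image counterpart, with the same caveat that the `image` argument name and input format are assumptions based on the surrounding code:

```python
import pandas as pd

from embedding_atlas.projection import compute_image_projection

df = pd.read_parquet("photos.parquet")  # hypothetical table of image paths

# A smaller batch keeps accelerator memory in check; the image default is 16.
compute_image_projection(df, image="image_path", batch_size=8)
```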