diff --git a/packages/backend/embedding_atlas/cli.py b/packages/backend/embedding_atlas/cli.py
index e48c0e7..a32636e 100644
--- a/packages/backend/embedding_atlas/cli.py
+++ b/packages/backend/embedding_atlas/cli.py
@@ -130,6 +130,12 @@ def find_available_port(start_port: int, max_attempts: int = 10, host="localhost
     default=False,
     help="Allow execution of remote code when loading models from Hugging Face Hub.",
 )
+@click.option(
+    "--batch-size",
+    type=int,
+    default=None,
+    help="Batch size for processing embeddings (default: 32 for text, 16 for images). Larger values use more memory but may be faster.",
+)
 @click.option(
     "--x",
     "x_column",
@@ -207,6 +213,7 @@ def main(
     enable_projection: bool,
     model: str | None,
     trust_remote_code: bool,
+    batch_size: int | None,
     x_column: str | None,
     y_column: str | None,
     neighbors_column: str | None,
@@ -280,6 +287,7 @@ def main(
             neighbors=new_neighbors_column,
             model=model,
             trust_remote_code=trust_remote_code,
+            batch_size=batch_size,
             umap_args=umap_args,
         )
     elif image is not None:
@@ -291,6 +299,7 @@ def main(
             neighbors=new_neighbors_column,
             model=model,
             trust_remote_code=trust_remote_code,
+            batch_size=batch_size,
             umap_args=umap_args,
         )
     else:
diff --git a/packages/backend/embedding_atlas/projection.py b/packages/backend/embedding_atlas/projection.py
index e018e61..6d5b849 100644
--- a/packages/backend/embedding_atlas/projection.py
+++ b/packages/backend/embedding_atlas/projection.py
@@ -92,6 +92,7 @@ def _projection_for_texts(
     texts: list[str],
     model: str | None = None,
     trust_remote_code: bool = False,
+    batch_size: int | None = None,
     umap_args: dict = {},
 ) -> Projection:
     if model is None:
@@ -102,6 +103,7 @@ def _projection_for_texts(
             "version": 1,
             "texts": texts,
             "model": model,
+            "batch_size": batch_size,
             "umap_args": umap_args,
         }
     )
@@ -115,11 +117,16 @@ def _projection_for_texts(
     # Import on demand.
     from sentence_transformers import SentenceTransformer
 
+    # Set default batch size if not provided
+    if batch_size is None:
+        batch_size = 32
+        logger.info("Using default batch size of %d for text. Adjust with --batch-size if you encounter memory issues or want to speed up processing.", batch_size)
+
     logger.info("Loading model %s...", model)
     transformer = SentenceTransformer(model, trust_remote_code=trust_remote_code)
 
-    logger.info("Running embedding for %d texts...", len(texts))
-    hidden_vectors = transformer.encode(texts)
+    logger.info("Running embedding for %d texts with batch size %d...", len(texts), batch_size)
+    hidden_vectors = transformer.encode(texts, batch_size=batch_size)
 
     result = _run_umap(hidden_vectors, umap_args)
     Projection.save(cpath, result)
@@ -130,6 +137,7 @@ def _projection_for_images(
     images: list,
     model: str | None = None,
     trust_remote_code: bool = False,
+    batch_size: int | None = None,
     umap_args: dict = {},
 ) -> Projection:
     if model is None:
@@ -140,6 +148,7 @@ def _projection_for_images(
             "version": 1,
             "images": images,
             "model": model,
+            "batch_size": batch_size,
             "umap_args": umap_args,
         }
     )
@@ -170,9 +179,13 @@ def load_image(value):
 
     pipe = pipeline("image-feature-extraction", model=model, device_map="auto")
 
-    logger.info("Running embedding for %d images...", len(images))
+    # Set default batch size if not provided
+    if batch_size is None:
+        batch_size = 16
+        logger.info("Using default batch size of %d for images. Adjust with --batch-size if you encounter memory issues or want to speed up processing.", batch_size)
+
+    logger.info("Running embedding for %d images with batch size %d...", len(images), batch_size)
     tensors = []
-    batch_size = 16
     current_batch = []
@@ -207,6 +220,7 @@ def compute_text_projection(
     neighbors: str | None = "neighbors",
     model: str | None = None,
     trust_remote_code: bool = False,
+    batch_size: int | None = None,
     umap_args: dict = {},
 ):
     """
@@ -225,6 +239,8 @@ def compute_text_projection(
         model: str, name or path of the SentenceTransformer model to use for embedding.
         trust_remote_code: bool, whether to trust and execute remote code when loading
             the model from HuggingFace Hub. Default is False.
+        batch_size: int, batch size for processing embeddings. Larger values use more
+            memory but may be faster. Default is 32.
         umap_args: dict, additional keyword arguments to pass to the UMAP algorithm
             (e.g., n_neighbors, min_dist, metric).
@@ -237,6 +253,7 @@ def compute_text_projection(
         list(text_series),
         model=model,
         trust_remote_code=trust_remote_code,
+        batch_size=batch_size,
         umap_args=umap_args,
     )
     data_frame[x] = proj.projection[:, 0]
@@ -314,6 +331,7 @@ def compute_image_projection(
     neighbors: str | None = "neighbors",
     model: str | None = None,
     trust_remote_code: bool = False,
+    batch_size: int | None = None,
     umap_args: dict = {},
 ):
     """
@@ -332,6 +350,8 @@ def compute_image_projection(
         model: str, name or path of the model to use for embedding.
         trust_remote_code: bool, whether to trust and execute remote code when loading
             the model from HuggingFace Hub. Default is False.
+        batch_size: int, batch size for processing images. Larger values use more
+            memory but may be faster. Default is 16.
         umap_args: dict, additional keyword arguments to pass to the UMAP algorithm
             (e.g., n_neighbors, min_dist, metric).
@@ -344,6 +364,7 @@ def compute_image_projection(
         list(image_series),
         model=model,
         trust_remote_code=trust_remote_code,
+        batch_size=batch_size,
         umap_args=umap_args,
     )
     data_frame[x] = proj.projection[:, 0]
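Usage sketch (not part of the diff): the snippet below shows how the new batch_size parameter would be passed through the public API. It assumes compute_text_projection takes the data frame and text column name as its leading arguments, which the diff does not show (only the later keyword parameters appear in the hunks); the sample DataFrame and the "description" column are illustrative.

    import pandas as pd

    from embedding_atlas.projection import compute_text_projection

    df = pd.DataFrame({"description": ["a red apple", "a green pear", "a sports car"]})

    # batch_size=None (the default) falls back to 32 for text per the diff above;
    # an explicit value trades memory for throughput.
    compute_text_projection(df, "description", batch_size=64)

From the command line, the option added in cli.py would be passed alongside the existing flags; the --text flag and file name here are assumed for illustration:

    embedding-atlas data.parquet --text description --batch-size 64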