Skip to content

Commit caad819

Browse files
authored
feat: add batch size parameter to CLI for Sentence Transformers (#24)
1 parent b066f03 commit caad819

File tree

2 files changed: +34 additions, −4 deletions

packages/backend/embedding_atlas/cli.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -130,6 +130,12 @@ def find_available_port(start_port: int, max_attempts: int = 10, host="localhost
130130
default=False,
131131
help="Allow execution of remote code when loading models from Hugging Face Hub.",
132132
)
133+
@click.option(
134+
"--batch-size",
135+
type=int,
136+
default=None,
137+
help="Batch size for processing embeddings (default: 32 for text, 16 for images). Larger values use more memory but may be faster.",
138+
)
133139
@click.option(
134140
"--x",
135141
"x_column",
@@ -207,6 +213,7 @@ def main(
207213
enable_projection: bool,
208214
model: str | None,
209215
trust_remote_code: bool,
216+
batch_size: int | None,
210217
x_column: str | None,
211218
y_column: str | None,
212219
neighbors_column: str | None,
@@ -280,6 +287,7 @@ def main(
280287
neighbors=new_neighbors_column,
281288
model=model,
282289
trust_remote_code=trust_remote_code,
290+
batch_size=batch_size,
283291
umap_args=umap_args,
284292
)
285293
elif image is not None:
@@ -291,6 +299,7 @@ def main(
291299
neighbors=new_neighbors_column,
292300
model=model,
293301
trust_remote_code=trust_remote_code,
302+
batch_size=batch_size,
294303
umap_args=umap_args,
295304
)
296305
else:

packages/backend/embedding_atlas/projection.py

Lines changed: 25 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -92,6 +92,7 @@ def _projection_for_texts(
9292
texts: list[str],
9393
model: str | None = None,
9494
trust_remote_code: bool = False,
95+
batch_size: int | None = None,
9596
umap_args: dict = {},
9697
) -> Projection:
9798
if model is None:
@@ -102,6 +103,7 @@ def _projection_for_texts(
102103
"version": 1,
103104
"texts": texts,
104105
"model": model,
106+
"batch_size": batch_size,
105107
"umap_args": umap_args,
106108
}
107109
)
@@ -115,11 +117,16 @@ def _projection_for_texts(
115117
# Import on demand.
116118
from sentence_transformers import SentenceTransformer
117119

120+
# Set default batch size if not provided
121+
if batch_size is None:
122+
batch_size = 32
123+
logger.info("Using default batch size of %d for text. Adjust with --batch-size if you encounter memory issues or want to speed up processing.", batch_size)
124+
118125
logger.info("Loading model %s...", model)
119126
transformer = SentenceTransformer(model, trust_remote_code=trust_remote_code)
120127

121-
logger.info("Running embedding for %d texts...", len(texts))
122-
hidden_vectors = transformer.encode(texts)
128+
logger.info("Running embedding for %d texts with batch size %d...", len(texts), batch_size)
129+
hidden_vectors = transformer.encode(texts, batch_size=batch_size)
123130

124131
result = _run_umap(hidden_vectors, umap_args)
125132
Projection.save(cpath, result)
@@ -130,6 +137,7 @@ def _projection_for_images(
130137
images: list,
131138
model: str | None = None,
132139
trust_remote_code: bool = False,
140+
batch_size: int | None = None,
133141
umap_args: dict = {},
134142
) -> Projection:
135143
if model is None:
@@ -140,6 +148,7 @@ def _projection_for_images(
140148
"version": 1,
141149
"images": images,
142150
"model": model,
151+
"batch_size": batch_size,
143152
"umap_args": umap_args,
144153
}
145154
)
@@ -170,9 +179,13 @@ def load_image(value):
170179

171180
pipe = pipeline("image-feature-extraction", model=model, device_map="auto")
172181

173-
logger.info("Running embedding for %d images...", len(images))
182+
# Set default batch size if not provided
183+
if batch_size is None:
184+
batch_size = 16
185+
logger.info("Using default batch size of %d for images. Adjust with --batch-size if you encounter memory issues or want to speed up processing.", batch_size)
186+
187+
logger.info("Running embedding for %d images with batch size %d...", len(images), batch_size)
174188
tensors = []
175-
batch_size = 16
176189

177190
current_batch = []
178191

@@ -207,6 +220,7 @@ def compute_text_projection(
207220
neighbors: str | None = "neighbors",
208221
model: str | None = None,
209222
trust_remote_code: bool = False,
223+
batch_size: int | None = None,
210224
umap_args: dict = {},
211225
):
212226
"""
@@ -225,6 +239,8 @@ def compute_text_projection(
225239
model: str, name or path of the SentenceTransformer model to use for embedding.
226240
trust_remote_code: bool, whether to trust and execute remote code when loading
227241
the model from HuggingFace Hub. Default is False.
242+
batch_size: int, batch size for processing embeddings. Larger values use more
243+
memory but may be faster. Default is 32.
228244
umap_args: dict, additional keyword arguments to pass to the UMAP algorithm
229245
(e.g., n_neighbors, min_dist, metric).
230246
@@ -237,6 +253,7 @@ def compute_text_projection(
237253
list(text_series),
238254
model=model,
239255
trust_remote_code=trust_remote_code,
256+
batch_size=batch_size,
240257
umap_args=umap_args,
241258
)
242259
data_frame[x] = proj.projection[:, 0]
@@ -314,6 +331,7 @@ def compute_image_projection(
314331
neighbors: str | None = "neighbors",
315332
model: str | None = None,
316333
trust_remote_code: bool = False,
334+
batch_size: int | None = None,
317335
umap_args: dict = {},
318336
):
319337
"""
@@ -332,6 +350,8 @@ def compute_image_projection(
332350
model: str, name or path of the model to use for embedding.
333351
trust_remote_code: bool, whether to trust and execute remote code when loading
334352
the model from HuggingFace Hub. Default is False.
353+
batch_size: int, batch size for processing images. Larger values use more
354+
memory but may be faster. Default is 16.
335355
umap_args: dict, additional keyword arguments to pass to the UMAP algorithm
336356
(e.g., n_neighbors, min_dist, metric).
337357
@@ -344,6 +364,7 @@ def compute_image_projection(
344364
list(image_series),
345365
model=model,
346366
trust_remote_code=trust_remote_code,
367+
batch_size=batch_size,
347368
umap_args=umap_args,
348369
)
349370
data_frame[x] = proj.projection[:, 0]

0 commit comments

Comments
 (0)