Skip to content

Commit 47f4db8

Browse files
nickprock and anakin87 authored
added truncate_dim to sentence transformers embedder (#8077)
* added truncate_dim to sentence transformers embedder
* Update haystack/components/embedders/sentence_transformers_document_embedder.py (Co-authored-by: Stefano Fiorucci <[email protected]>)
* Update releasenotes/notes/release-note-2b603a123cd36214.yaml (Co-authored-by: Stefano Fiorucci <[email protected]>)
* fixed parameter description
* added test for truncation to text embedder
* fix format
---------
Co-authored-by: Stefano Fiorucci <[email protected]>
1 parent b2aef21 commit 47f4db8

7 files changed

+90
-8
lines changed

haystack/components/embedders/backends/sentence_transformers_backend.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,22 @@ class _SentenceTransformersEmbeddingBackendFactory:
2222

2323
@staticmethod
2424
def get_embedding_backend(
25-
model: str, device: Optional[str] = None, auth_token: Optional[Secret] = None, trust_remote_code: bool = False
25+
model: str,
26+
device: Optional[str] = None,
27+
auth_token: Optional[Secret] = None,
28+
trust_remote_code: bool = False,
29+
truncate_dim: Optional[int] = None,
2630
):
27-
embedding_backend_id = f"{model}{device}{auth_token}"
31+
embedding_backend_id = f"{model}{device}{auth_token}{truncate_dim}"
2832

2933
if embedding_backend_id in _SentenceTransformersEmbeddingBackendFactory._instances:
3034
return _SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id]
3135
embedding_backend = _SentenceTransformersEmbeddingBackend(
32-
model=model, device=device, auth_token=auth_token, trust_remote_code=trust_remote_code
36+
model=model,
37+
device=device,
38+
auth_token=auth_token,
39+
trust_remote_code=trust_remote_code,
40+
truncate_dim=truncate_dim,
3341
)
3442
_SentenceTransformersEmbeddingBackendFactory._instances[embedding_backend_id] = embedding_backend
3543
return embedding_backend
@@ -46,13 +54,15 @@ def __init__(
4654
device: Optional[str] = None,
4755
auth_token: Optional[Secret] = None,
4856
trust_remote_code: bool = False,
57+
truncate_dim: Optional[int] = None,
4958
):
5059
sentence_transformers_import.check()
5160
self.model = SentenceTransformer(
5261
model_name_or_path=model,
5362
device=device,
5463
use_auth_token=auth_token.resolve_value() if auth_token else None,
5564
trust_remote_code=trust_remote_code,
65+
truncate_dim=truncate_dim,
5666
)
5767

5868
def embed(self, data: List[str], **kwargs) -> List[List[float]]:

haystack/components/embedders/sentence_transformers_document_embedder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
meta_fields_to_embed: Optional[List[str]] = None,
4545
embedding_separator: str = "\n",
4646
trust_remote_code: bool = False,
47+
truncate_dim: Optional[int] = None,
4748
):
4849
"""
4950
Create a SentenceTransformersDocumentEmbedder component.
@@ -73,6 +74,10 @@ def __init__(
7374
:param trust_remote_code:
7475
If `False`, only Hugging Face verified model architectures are allowed.
7576
If `True`, custom models and scripts are allowed.
77+
:param truncate_dim:
78+
The dimension to truncate sentence embeddings to. `None` does no truncation.
79+
If the model has not been trained with Matryoshka Representation Learning,
80+
truncation of embeddings can significantly affect performance.
7681
"""
7782

7883
self.model = model
@@ -86,6 +91,7 @@ def __init__(
8691
self.meta_fields_to_embed = meta_fields_to_embed or []
8792
self.embedding_separator = embedding_separator
8893
self.trust_remote_code = trust_remote_code
94+
self.truncate_dim = truncate_dim
8995

9096
def _get_telemetry_data(self) -> Dict[str, Any]:
9197
"""
@@ -113,6 +119,7 @@ def to_dict(self) -> Dict[str, Any]:
113119
meta_fields_to_embed=self.meta_fields_to_embed,
114120
embedding_separator=self.embedding_separator,
115121
trust_remote_code=self.trust_remote_code,
122+
truncate_dim=self.truncate_dim,
116123
)
117124

118125
@classmethod
@@ -141,6 +148,7 @@ def warm_up(self):
141148
device=self.device.to_torch_str(),
142149
auth_token=self.token,
143150
trust_remote_code=self.trust_remote_code,
151+
truncate_dim=self.truncate_dim,
144152
)
145153

146154
@component.output_types(documents=List[Document])

haystack/components/embedders/sentence_transformers_text_embedder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
progress_bar: bool = True,
4545
normalize_embeddings: bool = False,
4646
trust_remote_code: bool = False,
47+
truncate_dim: Optional[int] = None,
4748
):
4849
"""
4950
Create a SentenceTransformersTextEmbedder component.
@@ -71,6 +72,10 @@ def __init__(
7172
:param trust_remote_code:
7273
If `False`, permits only Hugging Face verified model architectures.
7374
If `True`, permits custom models and scripts.
75+
:param truncate_dim:
76+
The dimension to truncate sentence embeddings to. `None` does no truncation.
77+
If the model has not been trained with Matryoshka Representation Learning,
78+
truncation of embeddings can significantly affect performance.
7479
"""
7580

7681
self.model = model
@@ -82,6 +87,7 @@ def __init__(
8287
self.progress_bar = progress_bar
8388
self.normalize_embeddings = normalize_embeddings
8489
self.trust_remote_code = trust_remote_code
90+
self.truncate_dim = truncate_dim
8591

8692
def _get_telemetry_data(self) -> Dict[str, Any]:
8793
"""
@@ -107,6 +113,7 @@ def to_dict(self) -> Dict[str, Any]:
107113
progress_bar=self.progress_bar,
108114
normalize_embeddings=self.normalize_embeddings,
109115
trust_remote_code=self.trust_remote_code,
116+
truncate_dim=self.truncate_dim,
110117
)
111118

112119
@classmethod
@@ -135,6 +142,7 @@ def warm_up(self):
135142
device=self.device.to_torch_str(),
136143
auth_token=self.token,
137144
trust_remote_code=self.trust_remote_code,
145+
truncate_dim=self.truncate_dim,
138146
)
139147

140148
@component.output_types(embedding=List[float])
releasenotes/notes/release-note-2b603a123cd36214.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
features:
3+
- |
4+
Add `truncate_dim` parameter to Sentence Transformers Embedders, which allows truncating
5+
embeddings. Especially useful for models trained with Matryoshka Representation Learning.

test/components/embedders/test_sentence_transformers_document_embedder.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def test_init_default(self):
2525
assert embedder.meta_fields_to_embed == []
2626
assert embedder.embedding_separator == "\n"
2727
assert embedder.trust_remote_code is False
28+
assert embedder.truncate_dim is None
2829

2930
def test_init_with_parameters(self):
3031
embedder = SentenceTransformersDocumentEmbedder(
@@ -39,6 +40,7 @@ def test_init_with_parameters(self):
3940
meta_fields_to_embed=["test_field"],
4041
embedding_separator=" | ",
4142
trust_remote_code=True,
43+
truncate_dim=256,
4244
)
4345
assert embedder.model == "model"
4446
assert embedder.device == ComponentDevice.from_str("cuda:0")
@@ -51,6 +53,7 @@ def test_init_with_parameters(self):
5153
assert embedder.meta_fields_to_embed == ["test_field"]
5254
assert embedder.embedding_separator == " | "
5355
assert embedder.trust_remote_code
56+
assert embedder.truncate_dim == 256
5457

5558
def test_to_dict(self):
5659
component = SentenceTransformersDocumentEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
@@ -69,6 +72,7 @@ def test_to_dict(self):
6972
"embedding_separator": "\n",
7073
"meta_fields_to_embed": [],
7174
"trust_remote_code": False,
75+
"truncate_dim": None,
7276
},
7377
}
7478

@@ -85,6 +89,7 @@ def test_to_dict_with_custom_init_parameters(self):
8589
meta_fields_to_embed=["meta_field"],
8690
embedding_separator=" - ",
8791
trust_remote_code=True,
92+
truncate_dim=256,
8893
)
8994
data = component.to_dict()
9095

@@ -102,6 +107,7 @@ def test_to_dict_with_custom_init_parameters(self):
102107
"embedding_separator": " - ",
103108
"trust_remote_code": True,
104109
"meta_fields_to_embed": ["meta_field"],
110+
"truncate_dim": 256,
105111
},
106112
}
107113

@@ -118,6 +124,7 @@ def test_from_dict(self):
118124
"embedding_separator": " - ",
119125
"meta_fields_to_embed": ["meta_field"],
120126
"trust_remote_code": True,
127+
"truncate_dim": 256,
121128
}
122129
component = SentenceTransformersDocumentEmbedder.from_dict(
123130
{
@@ -136,6 +143,7 @@ def test_from_dict(self):
136143
assert component.embedding_separator == " - "
137144
assert component.trust_remote_code
138145
assert component.meta_fields_to_embed == ["meta_field"]
146+
assert component.truncate_dim == 256
139147

140148
def test_from_dict_no_default_parameters(self):
141149
component = SentenceTransformersDocumentEmbedder.from_dict(
@@ -155,6 +163,7 @@ def test_from_dict_no_default_parameters(self):
155163
assert component.embedding_separator == "\n"
156164
assert component.trust_remote_code is False
157165
assert component.meta_fields_to_embed == []
166+
assert component.truncate_dim is None
158167

159168
def test_from_dict_none_device(self):
160169
init_parameters = {
@@ -169,6 +178,7 @@ def test_from_dict_none_device(self):
169178
"embedding_separator": " - ",
170179
"meta_fields_to_embed": ["meta_field"],
171180
"trust_remote_code": True,
181+
"truncate_dim": None,
172182
}
173183
component = SentenceTransformersDocumentEmbedder.from_dict(
174184
{
@@ -187,6 +197,7 @@ def test_from_dict_none_device(self):
187197
assert component.embedding_separator == " - "
188198
assert component.trust_remote_code
189199
assert component.meta_fields_to_embed == ["meta_field"]
200+
assert component.truncate_dim is None
190201

191202
@patch(
192203
"haystack.components.embedders.sentence_transformers_document_embedder._SentenceTransformersEmbeddingBackendFactory"
@@ -198,7 +209,7 @@ def test_warmup(self, mocked_factory):
198209
mocked_factory.get_embedding_backend.assert_not_called()
199210
embedder.warm_up()
200211
mocked_factory.get_embedding_backend.assert_called_once_with(
201-
model="model", device="cpu", auth_token=None, trust_remote_code=False
212+
model="model", device="cpu", auth_token=None, trust_remote_code=False, truncate_dim=None
202213
)
203214

204215
@patch(

test/components/embedders/test_sentence_transformers_embedding_backend.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,18 @@ def test_factory_behavior(mock_sentence_transformer):
2828
@patch("haystack.components.embedders.backends.sentence_transformers_backend.SentenceTransformer")
2929
def test_model_initialization(mock_sentence_transformer):
3030
_SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
31-
model="model", device="cpu", auth_token=Secret.from_token("fake-api-token"), trust_remote_code=True
31+
model="model",
32+
device="cpu",
33+
auth_token=Secret.from_token("fake-api-token"),
34+
trust_remote_code=True,
35+
truncate_dim=256,
3236
)
3337
mock_sentence_transformer.assert_called_once_with(
34-
model_name_or_path="model", device="cpu", use_auth_token="fake-api-token", trust_remote_code=True
38+
model_name_or_path="model",
39+
device="cpu",
40+
use_auth_token="fake-api-token",
41+
trust_remote_code=True,
42+
truncate_dim=256,
3543
)
3644

3745

test/components/embedders/test_sentence_transformers_text_embedder.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def test_init_default(self):
2222
assert embedder.progress_bar is True
2323
assert embedder.normalize_embeddings is False
2424
assert embedder.trust_remote_code is False
25+
assert embedder.truncate_dim is None
2526

2627
def test_init_with_parameters(self):
2728
embedder = SentenceTransformersTextEmbedder(
@@ -34,6 +35,7 @@ def test_init_with_parameters(self):
3435
progress_bar=False,
3536
normalize_embeddings=True,
3637
trust_remote_code=True,
38+
truncate_dim=256,
3739
)
3840
assert embedder.model == "model"
3941
assert embedder.device == ComponentDevice.from_str("cuda:0")
@@ -43,7 +45,8 @@ def test_init_with_parameters(self):
4345
assert embedder.batch_size == 64
4446
assert embedder.progress_bar is False
4547
assert embedder.normalize_embeddings is True
46-
assert embedder.trust_remote_code
48+
assert embedder.trust_remote_code is True
49+
assert embedder.truncate_dim == 256
4750

4851
def test_to_dict(self):
4952
component = SentenceTransformersTextEmbedder(model="model", device=ComponentDevice.from_str("cpu"))
@@ -60,6 +63,7 @@ def test_to_dict(self):
6063
"progress_bar": True,
6164
"normalize_embeddings": False,
6265
"trust_remote_code": False,
66+
"truncate_dim": None,
6367
},
6468
}
6569

@@ -74,6 +78,7 @@ def test_to_dict_with_custom_init_parameters(self):
7478
progress_bar=False,
7579
normalize_embeddings=True,
7680
trust_remote_code=True,
81+
truncate_dim=256,
7782
)
7883
data = component.to_dict()
7984
assert data == {
@@ -88,6 +93,7 @@ def test_to_dict_with_custom_init_parameters(self):
8893
"progress_bar": False,
8994
"normalize_embeddings": True,
9095
"trust_remote_code": True,
96+
"truncate_dim": 256,
9197
},
9298
}
9399

@@ -109,6 +115,7 @@ def test_from_dict(self):
109115
"progress_bar": True,
110116
"normalize_embeddings": False,
111117
"trust_remote_code": False,
118+
"truncate_dim": None,
112119
},
113120
}
114121
component = SentenceTransformersTextEmbedder.from_dict(data)
@@ -121,6 +128,7 @@ def test_from_dict(self):
121128
assert component.progress_bar is True
122129
assert component.normalize_embeddings is False
123130
assert component.trust_remote_code is False
131+
assert component.truncate_dim is None
124132

125133
def test_from_dict_no_default_parameters(self):
126134
data = {
@@ -137,6 +145,7 @@ def test_from_dict_no_default_parameters(self):
137145
assert component.progress_bar is True
138146
assert component.normalize_embeddings is False
139147
assert component.trust_remote_code is False
148+
assert component.truncate_dim is None
140149

141150
def test_from_dict_none_device(self):
142151
data = {
@@ -151,6 +160,7 @@ def test_from_dict_none_device(self):
151160
"progress_bar": True,
152161
"normalize_embeddings": False,
153162
"trust_remote_code": False,
163+
"truncate_dim": 256,
154164
},
155165
}
156166
component = SentenceTransformersTextEmbedder.from_dict(data)
@@ -163,6 +173,7 @@ def test_from_dict_none_device(self):
163173
assert component.progress_bar is True
164174
assert component.normalize_embeddings is False
165175
assert component.trust_remote_code is False
176+
assert component.truncate_dim == 256
166177

167178
@patch(
168179
"haystack.components.embedders.sentence_transformers_text_embedder._SentenceTransformersEmbeddingBackendFactory"
@@ -172,7 +183,7 @@ def test_warmup(self, mocked_factory):
172183
mocked_factory.get_embedding_backend.assert_not_called()
173184
embedder.warm_up()
174185
mocked_factory.get_embedding_backend.assert_called_once_with(
175-
model="model", device="cpu", auth_token=None, trust_remote_code=False
186+
model="model", device="cpu", auth_token=None, trust_remote_code=False, truncate_dim=None
176187
)
177188

178189
@patch(
@@ -206,3 +217,24 @@ def test_run_wrong_input_format(self):
206217

207218
with pytest.raises(TypeError, match="SentenceTransformersTextEmbedder expects a string as input"):
208219
embedder.run(text=list_integers_input)
220+
221+
@pytest.mark.integration
222+
def test_run_trunc(self):
223+
"""
224+
The sentence-transformers/paraphrase-albert-small-v2 model maps sentences and paragraphs to a 768-dimensional dense vector space.
225+
"""
226+
checkpoint = "sentence-transformers/paraphrase-albert-small-v2"
227+
text = "a nice text to embed"
228+
229+
embedder_def = SentenceTransformersTextEmbedder(model=checkpoint)
230+
embedder_def.warm_up()
231+
result_def = embedder_def.run(text=text)
232+
embedding_def = result_def["embedding"]
233+
234+
embedder_trunc = SentenceTransformersTextEmbedder(model=checkpoint, truncate_dim=128)
235+
embedder_trunc.warm_up()
236+
result_trunc = embedder_trunc.run(text=text)
237+
embedding_trunc = result_trunc["embedding"]
238+
239+
assert len(embedding_def) == 768
240+
assert len(embedding_trunc) == 128

0 commit comments

Comments
 (0)