Skip to content

Commit 25d3520

Browse files
vblagojesjrlAmnah199dfokina
authored
feat: Add AnswerJoiner new component (#8122)
* Initial AnswerJoiner * Initial tests * Add release note * Resove mypy warning * Add custom join function * Serialize custom join function * Handle all Answer types, add integration test, improve pydoc * Make fixes * Add to API docs * Add more tests * Update haystack/components/joiners/answer_joiner.py Co-authored-by: Amna Mubashar <[email protected]> * Update docstrings and release notes * update docstrings --------- Co-authored-by: Sebastian Husch Lee <[email protected]> Co-authored-by: Sebastian Husch Lee <[email protected]> Co-authored-by: Amna Mubashar <[email protected]> Co-authored-by: Darja Fokina <[email protected]>
1 parent 3d1ad10 commit 25d3520

File tree

5 files changed

+321
-2
lines changed

5 files changed

+321
-2
lines changed

docs/pydoc/config/joiners_api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
loaders:
22
- type: haystack_pydoc_tools.loaders.CustomPythonLoader
33
search_path: [../../../haystack/components/joiners]
4-
modules: ["document_joiner", "branch"]
4+
modules: ["document_joiner", "branch", "answer_joiner"]
55
ignore_when_discovered: ["__init__"]
66
processors:
77
- type: filter

haystack/components/joiners/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5+
from .answer_joiner import AnswerJoiner
56
from .branch import BranchJoiner
67
from .document_joiner import DocumentJoiner
78

8-
__all__ = ["DocumentJoiner", "BranchJoiner"]
9+
__all__ = ["DocumentJoiner", "BranchJoiner", "AnswerJoiner"]
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import itertools
6+
from enum import Enum
7+
from math import inf
8+
from typing import Any, Callable, Dict, List, Optional, Union
9+
10+
from haystack import component, default_from_dict, default_to_dict, logging
11+
from haystack.core.component.types import Variadic
12+
from haystack.dataclasses.answer import ExtractedAnswer, ExtractedTableAnswer, GeneratedAnswer
13+
14+
AnswerType = Union[GeneratedAnswer, ExtractedTableAnswer, ExtractedAnswer]
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
class JoinMode(Enum):
20+
"""
21+
Enum for AnswerJoiner join modes.
22+
"""
23+
24+
CONCATENATE = "concatenate"
25+
26+
def __str__(self):
27+
return self.value
28+
29+
@staticmethod
30+
def from_str(string: str) -> "JoinMode":
31+
"""
32+
Convert a string to a JoinMode enum.
33+
"""
34+
enum_map = {e.value: e for e in JoinMode}
35+
mode = enum_map.get(string)
36+
if mode is None:
37+
msg = f"Unknown join mode '{string}'. Supported modes in AnswerJoiner are: {list(enum_map.keys())}"
38+
raise ValueError(msg)
39+
return mode
40+
41+
42+
@component
43+
class AnswerJoiner:
44+
"""
45+
Merges multiple lists of `Answer` objects into a single list.
46+
47+
Use this component to combine answers from different Generators into a single list.
48+
Currently, the component supports only one join mode: `CONCATENATE`.
49+
This mode concatenates multiple lists of answers into a single list.
50+
51+
### Usage example
52+
53+
In this example, AnswerJoiner merges answers from two different Generators:
54+
55+
```python
56+
from haystack.components.builders import AnswerBuilder
57+
from haystack.components.joiners import AnswerJoiner
58+
59+
from haystack.core.pipeline import Pipeline
60+
61+
from haystack.components.generators.chat import OpenAIChatGenerator
62+
from haystack.dataclasses import ChatMessage
63+
64+
65+
query = "What's Natural Language Processing?"
66+
messages = [ChatMessage.from_system("You are a helpful, respectful and honest assistant. Be super concise."),
67+
ChatMessage.from_user(query)]
68+
69+
pipe = Pipeline()
70+
pipe.add_component("gpt-4o", OpenAIChatGenerator(model="gpt-4o"))
71+
pipe.add_component("llama", OpenAIChatGenerator(model="gpt-3.5-turbo"))
72+
pipe.add_component("aba", AnswerBuilder())
73+
pipe.add_component("abb", AnswerBuilder())
74+
pipe.add_component("joiner", AnswerJoiner())
75+
76+
pipe.connect("gpt-4o.replies", "aba")
77+
pipe.connect("llama.replies", "abb")
78+
pipe.connect("aba.answers", "joiner")
79+
pipe.connect("abb.answers", "joiner")
80+
81+
results = pipe.run(data={"gpt-4o": {"messages": messages},
82+
"llama": {"messages": messages},
83+
"aba": {"query": query},
84+
"abb": {"query": query}})
85+
```
86+
"""
87+
88+
def __init__(
89+
self,
90+
join_mode: Union[str, JoinMode] = JoinMode.CONCATENATE,
91+
top_k: Optional[int] = None,
92+
sort_by_score: bool = False,
93+
):
94+
"""
95+
Creates an AnswerJoiner component.
96+
97+
:param join_mode:
98+
Specifies the join mode to use. Available modes:
99+
- `concatenate`: Concatenates multiple lists of Answers into a single list.
100+
:param top_k:
101+
The maximum number of Answers to return.
102+
:param sort_by_score:
103+
If `True`, sorts the documents by score in descending order.
104+
If a document has no score, it is handled as if its score is -infinity.
105+
"""
106+
if isinstance(join_mode, str):
107+
join_mode = JoinMode.from_str(join_mode)
108+
join_mode_functions: Dict[JoinMode, Callable[[List[List[AnswerType]]], List[AnswerType]]] = {
109+
JoinMode.CONCATENATE: self._concatenate
110+
}
111+
self.join_mode_function: Callable[[List[List[AnswerType]]], List[AnswerType]] = join_mode_functions[join_mode]
112+
self.join_mode = join_mode
113+
self.top_k = top_k
114+
self.sort_by_score = sort_by_score
115+
116+
@component.output_types(answers=List[AnswerType])
117+
def run(self, answers: Variadic[List[AnswerType]], top_k: Optional[int] = None):
118+
"""
119+
Joins multiple lists of Answers into a single list depending on the `join_mode` parameter.
120+
121+
:param answers:
122+
Nested list of Answers to be merged.
123+
124+
:param top_k:
125+
The maximum number of Answers to return. Overrides the instance's `top_k` if provided.
126+
127+
:returns:
128+
A dictionary with the following keys:
129+
- `answers`: Merged list of Answers
130+
"""
131+
answers_list = list(answers)
132+
join_function = self.join_mode_function
133+
output_answers: List[AnswerType] = join_function(answers_list)
134+
135+
if self.sort_by_score:
136+
output_answers = sorted(
137+
output_answers, key=lambda answer: answer.score if hasattr(answer, "score") else -inf, reverse=True
138+
)
139+
140+
top_k = top_k or self.top_k
141+
if top_k:
142+
output_answers = output_answers[:top_k]
143+
return {"answers": output_answers}
144+
145+
def _concatenate(self, answer_lists: List[List[AnswerType]]) -> List[AnswerType]:
146+
"""
147+
Concatenate multiple lists of Answers, flattening them into a single list and sorting by score.
148+
149+
:param answer_lists: List of lists of Answers to be flattened.
150+
"""
151+
return list(itertools.chain.from_iterable(answer_lists))
152+
153+
def to_dict(self) -> Dict[str, Any]:
154+
"""
155+
Serializes the component to a dictionary.
156+
157+
:returns:
158+
Dictionary with serialized data.
159+
"""
160+
return default_to_dict(self, join_mode=str(self.join_mode), top_k=self.top_k, sort_by_score=self.sort_by_score)
161+
162+
@classmethod
163+
def from_dict(cls, data: Dict[str, Any]) -> "AnswerJoiner":
164+
"""
165+
Deserializes the component from a dictionary.
166+
167+
:param data:
168+
The dictionary to deserialize from.
169+
:returns:
170+
The deserialized component.
171+
"""
172+
return default_from_dict(cls, data)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
features:
3+
- |
4+
Introduced a new AnswerJoiner component that allows joining multiple lists of Answers into a single list using
5+
the Concatenate join mode.
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
import os
5+
6+
import pytest
7+
8+
from haystack.components.builders import AnswerBuilder
9+
10+
from haystack import Document, Pipeline
11+
from haystack.dataclasses.answer import ExtractedAnswer, GeneratedAnswer, ExtractedTableAnswer
12+
from haystack.components.generators.chat import OpenAIChatGenerator
13+
from haystack.components.joiners.answer_joiner import AnswerJoiner, JoinMode
14+
from haystack.dataclasses import ChatMessage
15+
16+
17+
class TestAnswerJoiner:
18+
def test_init(self):
19+
joiner = AnswerJoiner()
20+
assert joiner.join_mode == JoinMode.CONCATENATE
21+
assert joiner.top_k is None
22+
assert joiner.sort_by_score is False
23+
24+
def test_init_with_custom_parameters(self):
25+
joiner = AnswerJoiner(join_mode="concatenate", top_k=5, sort_by_score=True)
26+
assert joiner.join_mode == JoinMode.CONCATENATE
27+
assert joiner.top_k == 5
28+
assert joiner.sort_by_score is True
29+
30+
def test_to_dict(self):
31+
joiner = AnswerJoiner()
32+
data = joiner.to_dict()
33+
assert data == {
34+
"type": "haystack.components.joiners.answer_joiner.AnswerJoiner",
35+
"init_parameters": {"join_mode": "concatenate", "top_k": None, "sort_by_score": False},
36+
}
37+
38+
def test_to_from_dict_custom_parameters(self):
39+
joiner = AnswerJoiner("concatenate", top_k=5, sort_by_score=True)
40+
data = joiner.to_dict()
41+
assert data == {
42+
"type": "haystack.components.joiners.answer_joiner.AnswerJoiner",
43+
"init_parameters": {"join_mode": "concatenate", "top_k": 5, "sort_by_score": True},
44+
}
45+
46+
deserialized_joiner = AnswerJoiner.from_dict(data)
47+
assert deserialized_joiner.join_mode == JoinMode.CONCATENATE
48+
assert deserialized_joiner.top_k == 5
49+
assert deserialized_joiner.sort_by_score is True
50+
51+
def test_from_dict(self):
52+
data = {"type": "haystack.components.joiners.answer_joiner.AnswerJoiner", "init_parameters": {}}
53+
answer_joiner = AnswerJoiner.from_dict(data)
54+
assert answer_joiner.join_mode == JoinMode.CONCATENATE
55+
assert answer_joiner.top_k is None
56+
assert answer_joiner.sort_by_score is False
57+
58+
def test_from_dict_customs_parameters(self):
59+
data = {
60+
"type": "haystack.components.joiners.answer_joiner.AnswerJoiner",
61+
"init_parameters": {"join_mode": "concatenate", "top_k": 5, "sort_by_score": True},
62+
}
63+
answer_joiner = AnswerJoiner.from_dict(data)
64+
assert answer_joiner.join_mode == JoinMode.CONCATENATE
65+
assert answer_joiner.top_k == 5
66+
assert answer_joiner.sort_by_score is True
67+
68+
def test_empty_list(self):
69+
joiner = AnswerJoiner()
70+
result = joiner.run([])
71+
assert result == {"answers": []}
72+
73+
def test_list_of_empty_lists(self):
74+
joiner = AnswerJoiner()
75+
result = joiner.run([[], []])
76+
assert result == {"answers": []}
77+
78+
def test_list_of_single_answer(self):
79+
joiner = AnswerJoiner()
80+
answers = [
81+
GeneratedAnswer(query="a", data="a", meta={}, documents=[Document(content="a")]),
82+
GeneratedAnswer(query="b", data="b", meta={}, documents=[Document(content="b")]),
83+
GeneratedAnswer(query="c", data="c", meta={}, documents=[Document(content="c")]),
84+
]
85+
result = joiner.run([answers])
86+
assert result == {"answers": answers}
87+
88+
def test_two_lists_of_generated_answers(self):
89+
joiner = AnswerJoiner()
90+
answers1 = [GeneratedAnswer(query="a", data="a", meta={}, documents=[Document(content="a")])]
91+
answers2 = [GeneratedAnswer(query="d", data="d", meta={}, documents=[Document(content="d")])]
92+
result = joiner.run([answers1, answers2])
93+
assert result == {"answers": answers1 + answers2}
94+
95+
def test_multiple_lists_of_mixed_answers(self):
96+
joiner = AnswerJoiner()
97+
answers1 = [GeneratedAnswer(query="a", data="a", meta={}, documents=[Document(content="a")])]
98+
answers2 = [ExtractedAnswer(query="d", score=0.9, meta={}, document=Document(content="d"))]
99+
answers3 = [ExtractedTableAnswer(query="e", score=0.7, meta={}, document=Document(content="e"))]
100+
answers4 = [GeneratedAnswer(query="f", data="f", meta={}, documents=[Document(content="f")])]
101+
all_answers = answers1 + answers2 + answers3 + answers4 # type: ignore
102+
result = joiner.run([answers1, answers2, answers3, answers4])
103+
assert result == {"answers": all_answers}
104+
105+
def test_unsupported_join_mode(self):
106+
unsupported_mode = "unsupported_mode"
107+
with pytest.raises(ValueError):
108+
AnswerJoiner(join_mode=unsupported_mode)
109+
110+
@pytest.mark.skipif(not os.environ.get("OPENAI_API_KEY", ""), reason="Needs OPENAI_API_KEY to run this test.")
111+
@pytest.mark.integration
112+
def test_with_pipeline(self):
113+
query = "What's Natural Language Processing?"
114+
messages = [
115+
ChatMessage.from_system("You are a helpful, respectful and honest assistant. Be super concise."),
116+
ChatMessage.from_user(query),
117+
]
118+
119+
pipe = Pipeline()
120+
pipe.add_component("gpt-4o", OpenAIChatGenerator(model="gpt-4o"))
121+
pipe.add_component("llama", OpenAIChatGenerator(model="gpt-3.5-turbo"))
122+
pipe.add_component("aba", AnswerBuilder())
123+
pipe.add_component("abb", AnswerBuilder())
124+
pipe.add_component("joiner", AnswerJoiner())
125+
126+
pipe.connect("gpt-4o.replies", "aba")
127+
pipe.connect("llama.replies", "abb")
128+
pipe.connect("aba.answers", "joiner")
129+
pipe.connect("abb.answers", "joiner")
130+
131+
results = pipe.run(
132+
data={
133+
"gpt-4o": {"messages": messages},
134+
"llama": {"messages": messages},
135+
"aba": {"query": query},
136+
"abb": {"query": query},
137+
}
138+
)
139+
140+
assert "joiner" in results
141+
assert len(results["joiner"]["answers"]) == 2

0 commit comments

Comments
 (0)